In [28]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


In [29]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 10
input_size = 784
epochs = 2
learning_rate = 0.01
batch_size = 100

In [30]:
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Lambda(lambda x: x.repeat(3, 1, 1)) # gray channel to rgb          
])
train_datasets = torchvision.datasets.MNIST(root='./datasets/', train=True, transform=transform)
test_datasets = torchvision.datasets.MNIST(root='./datasets/', train=False, transform=transform)

train_loader = DataLoader(train_datasets, batch_size=batch_size, shuffle=True,)
test_loader = DataLoader(test_datasets, batch_size=batch_size, shuffle=True)

In [31]:
model = torchvision.models.resnet18(pretrained=True)
# freeze params
for param in model.parameters():
    param.requires_grad = False

# change 3 channel to single channel
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)


In [32]:
# training loop
total_steps = len(train_loader)
for epoch in range(epochs):
    for idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        # forward
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # backward
        optimizer.zero_grad()
        loss.backward()
        
        # update weight
        optimizer.step()
        
        
        if (idx+1) % 100 == 0:
            print(f"epoch {epoch + 1} / {epochs}, step {idx+1} / {total_steps}, loss = {loss:.4f}")
            

epoch 1 / 2, step 100 / 600, loss = 0.9667
epoch 1 / 2, step 200 / 600, loss = 0.6079
epoch 1 / 2, step 300 / 600, loss = 0.9426
epoch 1 / 2, step 400 / 600, loss = 0.3553
epoch 1 / 2, step 500 / 600, loss = 0.6085
epoch 1 / 2, step 600 / 600, loss = 0.5233
epoch 2 / 2, step 100 / 600, loss = 0.6308
epoch 2 / 2, step 200 / 600, loss = 0.7595
epoch 2 / 2, step 300 / 600, loss = 0.7638
epoch 2 / 2, step 400 / 600, loss = 0.6811
epoch 2 / 2, step 500 / 600, loss = 0.5391
epoch 2 / 2, step 600 / 600, loss = 0.5764


In [33]:
with torch.no_grad():
    n_corrects = 0
    n_samples = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        
        _, predictions = torch.max(outputs, 1)
        n_samples += labels.shape[0] 
        n_corrects += (predictions == labels).sum().item()
    
    acc = n_corrects / n_samples * 100.0
    print(f'accuracy = {acc:.2f}')

accuracy = 86.78
