In [2]:
%run -i cnn_utils.py
%run -i data_setup.py
%run -i restnet.py

In [3]:
train_loader, test_loader, classes = load_data(batch_size=32, dataset='signs_dataset')

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
from torch import optim
from torchmetrics import Accuracy

In [12]:
torch.manual_seed(42)
no_blocks = [3, 4, 6, 3]

model = ResNet(no_blocks, classes=len(classes)).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
acc_fn = Accuracy(task="multiclass", num_classes=len(classes)).to(device)

for epoch in range(10):
  loss_train, acc_train = 0., 0.
  loss_test, acc_test = 0., 0.

  model.train()
  for x, y in train_loader:
    x, y = x.to(device), y.to(torch.long).to(device)
    optimizer.zero_grad()

    output = model(x)
    loss = loss_fn(output, y.squeeze())
    acc_train += acc_fn(output, y.squeeze())
    loss.backward()
    optimizer.step()
    loss_train += loss.item()

  model.eval()
  with torch.inference_mode():
    for x, y in test_loader:
      x, y = x.to(device), y.to(torch.long).to(device)
      test_output = model(x)
      test_loss = loss_fn(test_output, y.squeeze())
      loss_test += test_loss.item()
      acc_test += acc_fn(test_output, y.squeeze())
    loss_test /=len(test_loader)
    acc_test /=len(test_loader)

      
  loss_train /= len(train_loader)
  acc_train /= len(train_loader)

  #loss_test = loss_test/len(test_loader)
  print(f"Train: {loss_train:.3f} | {acc_train:.3f} || Test: {loss_test:.3} | {acc_test:.3f}")
  #print(f"{loss_train:.3f}")



Train: 1.656 | 0.311 || Test: 16.6 | 0.166
Train: 0.713 | 0.721 || Test: 0.898 | 0.664
Train: 0.388 | 0.862 || Test: 0.159 | 0.972
Train: 0.172 | 0.941 || Test: 0.0703 | 0.983
Train: 0.138 | 0.951 || Test: 0.109 | 0.959
Train: 0.094 | 0.968 || Test: 0.0457 | 0.985
Train: 0.079 | 0.970 || Test: 0.0223 | 0.995
Train: 0.063 | 0.975 || Test: 0.0347 | 0.989
Train: 0.080 | 0.970 || Test: 0.0241 | 0.994
Train: 0.036 | 0.993 || Test: 0.0205 | 0.992
