In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from dataset import get_dataloaders
import matplotlib.pyplot as plt
from torchvision import models

In [2]:
# Build the train/test loaders and the class list from the "Cyrillic" folder.
# NOTE(review): image_size=224 / num_channels=3 presumably match the RGB
# 224px input the ImageNet-pretrained ResNet expects — the actual transform
# lives in dataset.get_dataloaders; confirm there.
train_loader, test_loader, classes = get_dataloaders(
    path_to_data="Cyrillic",
    image_size=224,
    num_channels=3,
)

In [3]:
# Width of the new classifier head: one output logit per class label.
num_classes = len(classes)
# Fail fast here instead of building a zero-width nn.Linear head later.
assert num_classes > 0, "get_dataloaders returned no classes; check path_to_data"

In [4]:
# Pick the fastest available backend: NVIDIA CUDA first, then Apple's MPS,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [5]:
# Fine-tune an ImageNet-pretrained ResNet-18: keep the backbone and replace
# the final fully-connected layer with a fresh one sized for our classes.
resnet = models.resnet18(weights="IMAGENET1K_V1")
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(in_features=num_ftrs, out_features=num_classes)
resnet.to(device)

# All parameters (backbone and new head) are optimized; nothing is frozen.
optimizer_ft = optim.Adam(resnet.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

    In training mode every batch goes through ``zero_grad`` / forward /
    ``backward`` / ``step``. In evaluation mode the model is put in
    ``eval()`` and autograd is disabled.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples actually seen.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # model.train(False) is the same as model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            if training:
                optimizer.zero_grad()
            logits = model(data)
            batch_loss = criterion(logits, target)
            if training:
                batch_loss.backward()
                optimizer.step()
            batch_size = target.size(0)
            # criterion presumably returns the batch *mean* (reduction='mean');
            # weight it by batch size so a short final batch is not over-counted.
            loss_sum += batch_loss.item() * batch_size
            correct += (logits.argmax(dim=1) == target).sum().item()
            seen += batch_size
    if seen == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / seen, correct / seen


def train_model(model, criterion, optimizer, epochs=25, train_loader=None,
                test_loader=None, device=None,
                save_path="weights/best_model_resnet.pth"):
    """Train ``model`` for ``epochs`` epochs, evaluating on the test loader after each.

    After every epoch the train/test loss and accuracy are printed and
    recorded. Whenever the test loss reaches a new minimum, the model's
    ``state_dict`` is saved to ``save_path``.

    Args:
        model: the ``nn.Module`` to train (its weights are updated in place).
        criterion: loss called as ``criterion(logits, target)``; assumed to use
            ``reduction='mean'`` (the ``nn.CrossEntropyLoss`` default), since
            each batch value is re-weighted by batch size to form a per-sample
            epoch mean.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over the training loader.
        train_loader: iterable of ``(data, target)`` batches; defaults to the
            notebook-level ``train_loader``.
        test_loader: iterable of ``(data, target)`` batches; defaults to the
            notebook-level ``test_loader``.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU if it has none). The model itself is not moved.
        save_path: checkpoint file for the best-test-loss weights. Missing
            parent directories are created. ``None`` disables saving.

    Returns:
        ``(train_loss_history, test_loss_history, train_accuracy_history,
        test_accuracy_history)``, each a list with one float per epoch.

    Raises:
        ValueError: if either loader yields no samples.
    """
    import os

    # Fall back to the notebook-level loaders so existing calls keep working.
    if train_loader is None:
        train_loader = globals()["train_loader"]
    if test_loader is None:
        test_loader = globals()["test_loader"]
    if device is None:
        # Send batches wherever the model's weights already live.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    train_loss_history = []
    test_loss_history = []
    train_accuracy_history = []
    test_accuracy_history = []
    best_loss = float('inf')
    for epoch in range(epochs):
        train_epoch_loss, train_epoch_accuracy = _run_epoch(
            model, train_loader, criterion, device, optimizer)
        train_loss_history.append(train_epoch_loss)
        train_accuracy_history.append(train_epoch_accuracy)

        test_epoch_loss, test_epoch_accuracy = _run_epoch(
            model, test_loader, criterion, device)
        test_loss_history.append(test_epoch_loss)
        test_accuracy_history.append(test_epoch_accuracy)

        print(f"epoch: {epoch+1}")
        print(f"train loss: {train_epoch_loss}, train accuracy: {train_epoch_accuracy}, test loss: {test_epoch_loss}, test accuracy: {test_epoch_accuracy}")

        # Checkpoint whenever the test loss reaches a new minimum.
        # NOTE(review): selecting on the test set leaks it into model selection;
        # a separate validation split would keep the final test score unbiased.
        if test_epoch_loss < best_loss:
            best_loss = test_epoch_loss
            if save_path is not None:
                # torch.save does not create missing parent directories.
                save_dir = os.path.dirname(save_path)
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                torch.save(model.state_dict(), save_path)

    return train_loss_history,test_loss_history,train_accuracy_history,test_accuracy_history


: 

In [None]:
# Fine-tune the ResNet; the four per-epoch history lists are plotted below.
(train_loss, test_loss,
 train_accuracy, test_accuracy) = train_model(resnet,
                                              criterion=criterion,
                                              optimizer=optimizer_ft)



epoch: 1
train loss: 1.726471225420634, train accuracy: 0.6090595765632694, test loss: 0.6871735801299413, test accuracy: 0.8595339678372169


In [None]:
def build_plot(train, test, label, xlabel="epoch", legend_loc="upper right"):
    """Plot per-epoch train and test curves for one metric on a fresh figure.

    Epochs are numbered from 1 on the x-axis, matching the ``epoch: N`` lines
    printed during training.

    Args:
        train: sequence of per-epoch training values.
        test: sequence of per-epoch test values.
        label: metric name; used for the title, the y-axis label and the
            legend entries ("train <label>", "test <label>").
        xlabel: x-axis label.
        legend_loc: matplotlib legend location.
    """
    _, ax = plt.subplots()
    ax.plot(range(1, len(train) + 1), train, label="train " + label)
    ax.plot(range(1, len(test) + 1), test, label="test " + label)
    ax.set(title=label, xlabel=xlabel, ylabel=label)
    ax.legend(loc=legend_loc)
    plt.show()

In [None]:
# One figure per metric: loss first, then accuracy.
for (train_curve, test_curve), metric in (
    ((train_loss, test_loss), 'Loss'),
    ((train_accuracy, test_accuracy), 'Accuracy'),
):
    build_plot(train_curve, test_curve, label=metric)