In [None]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from datetime import datetime
import numpy as np
import time
from torchsummary import summary

In [None]:
# Load the TensorBoard notebook extension; it provides the %tensorboard magic used to view the training logs.
%load_ext tensorboard

In [None]:
def train_step(model, loss_fn, opt, loader, calc_accuracy):
    """Run one optimisation pass of ``model`` over every batch in ``loader``.

    Args:
        model: module called as ``model(features)``; the caller is expected to
            have put it in training mode.
        loss_fn: callable ``loss_fn(y_pred, labels)`` returning a scalar tensor.
        opt: optimizer over ``model``'s parameters.
        loader: iterable yielding ``(features, labels)`` batches.
        calc_accuracy: callable ``calc_accuracy(y_pred_np, labels_np)`` returning
            the number of correct predictions in the batch.

    Returns:
        ``(avg_loss, accuracy)`` -- the mean of the per-batch losses as a Python
        float, and the fraction of all samples counted as correct (weighted by
        sample, so a smaller final batch does not skew it).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    total_loss = 0.0
    correct = 0
    n_batches = 0
    n_samples = 0
    compute_time = 0.0
    epoch_start = time.perf_counter()

    for features, labels in loader:
        batch_start = time.perf_counter()

        opt.zero_grad()
        y_pred = model(features)
        loss = loss_fn(y_pred, labels)
        loss.backward()
        opt.step()

        # .item() extracts a plain float; accumulating the tensor itself would
        # keep chaining every batch's autograd graph into one growing expression.
        total_loss += loss.item()
        # force=True copies to host memory, so this also works for accelerator tensors.
        correct += calc_accuracy(y_pred.numpy(force=True), labels.numpy(force=True))
        n_batches += 1
        n_samples += len(labels)

        # Time spent in forward/backward/step, excluding data loading.
        compute_time += time.perf_counter() - batch_start

    if n_batches == 0:
        raise ValueError('loader yielded no batches')

    print('Compute time {:.2f}s, total time {:.2f}s'.format(
        compute_time, time.perf_counter() - epoch_start))
    return total_loss / n_batches, float(correct) / n_samples

In [None]:
def _evaluate(model, loss_fn, loader, calc_accuracy):
    """Evaluate ``model`` on every batch of ``loader`` without tracking gradients.

    Returns:
        ``(avg_loss, accuracy)`` -- the mean of the per-batch losses as a Python
        float, and the fraction of all samples that ``calc_accuracy`` counts as
        correct.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    total_loss = 0.0
    correct = 0
    n_batches = 0
    n_samples = 0
    with torch.inference_mode():
        for features, labels in loader:
            y_pred = model(features)
            total_loss += loss_fn(y_pred, labels).item()
            # force=True copies to host memory, so this also works for accelerator tensors.
            correct += calc_accuracy(y_pred.numpy(force=True), labels.numpy(force=True))
            n_batches += 1
            n_samples += len(labels)
    if n_batches == 0:
        raise ValueError('val_loader yielded no batches')
    return total_loss / n_batches, float(correct) / n_samples


def train(model, loss_fn, opt, train_loader, val_loader, save_treshold=10, epochs=10, model_name='model_name'):
    """Train ``model`` for ``epochs`` epochs, validating and logging after each one.

    Each epoch runs ``train_step`` on ``train_loader`` and then evaluates on
    ``val_loader``. Train/validation loss and accuracy are printed and written
    to TensorBoard under ``runs/<model_name>_<timestamp>``. The validation loss
    drives a ``ReduceLROnPlateau`` scheduler. Every ``save_treshold`` epochs the
    ``state_dict`` is saved to ``<model_name>_<timestamp>_<epoch>`` in the
    current directory, where ``<epoch>`` is the 1-based epoch number.

    Args:
        model: module to train; switched between train and eval mode here.
        loss_fn: callable ``loss_fn(y_pred, labels)`` returning a scalar tensor.
        opt: optimizer over ``model``'s parameters.
        train_loader: iterable of ``(features, labels)`` training batches.
        val_loader: iterable of ``(features, labels)`` validation batches.
        save_treshold: checkpoint interval in epochs.
        epochs: number of epochs to run.
        model_name: prefix for the TensorBoard run directory and checkpoint files.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """

    def calc_accuracy(y, labels):
        """Return how many rows of ``y`` have their argmax equal to ``labels``."""
        return (np.argmax(y, axis=-1) == labels).sum()

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter('runs/' + model_name + '_{}'.format(timestamp))
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'min')

    # close() flushes pending events and releases the event file even when
    # training stops early (exception or notebook interrupt).
    try:
        for epoch in range(epochs):
            # 1-based epoch number, used consistently for printing, logging and checkpoints.
            step = epoch + 1
            epoch_start = time.time()
            print('EPOCH {}:'.format(step))

            model.train()
            avg_loss, avg_acc = train_step(model, loss_fn, opt, train_loader, calc_accuracy)

            model.eval()
            avg_vloss, avg_vacc = _evaluate(model, loss_fn, val_loader, calc_accuracy)

            scheduler.step(avg_vloss)

            print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
            print('ACC train {} valid {}'.format(avg_acc, avg_vacc))

            writer.add_scalars('Training vs. Validation Loss',
                               {'Training': avg_loss, 'Validation': avg_vloss},
                               step)
            writer.add_scalars('Training vs. Validation Acc',
                               {'Training': avg_acc, 'Validation': avg_vacc},
                               step)

            if step % save_treshold == 0:
                model_path = model_name + '_{}_{}'.format(timestamp, step)
                torch.save(model.state_dict(), model_path)

            print("Time per epoch {}s".format(time.time() - epoch_start))
    finally:
        writer.close()

In [None]:
# Single-channel images: ToTensor scales pixels to [0, 1], then Normalize maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Train split first, then the held-out split; both are cached under ./data.
training_set, validation_set = (
    torchvision.datasets.FashionMNIST('./data', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

training_loader = torch.utils.data.DataLoader(
    training_set,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)
# NOTE(review): batch_size=20000 presumably exceeds the validation split size,
# so each validation pass is a single batch -- confirm memory is acceptable.
validation_loader = torch.utils.data.DataLoader(
    validation_set,
    batch_size=20000,
    shuffle=False,
    num_workers=4,
)

# Human-readable class names, presumably ordered by label index -- verify against the dataset docs.
classes = (
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot',
)

In [None]:
class GarmentClassifier(nn.Module):
    """LeNet-style CNN for 1x28x28 images with 10 output classes.

    ``forward`` returns raw logits, which is what ``nn.CrossEntropyLoss``
    expects (it applies log-softmax internally; applying softmax first would
    squash the gradients). Use ``predict_proba`` for class probabilities.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)    # 28x28 -> 24x24, pooled to 12x12
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)   # 12x12 -> 8x8, pooled to 4x4
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        # Parameter-free; used only by predict_proba, not by forward.
        self.smax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for input of shape (batch, 1, 28, 28)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # flatten keeps the batch dimension, so a wrongly-sized input fails in
        # fc1 instead of being silently re-batched as view(-1, 256) would do.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def predict_proba(self, x):
        """Return softmax class probabilities of shape (batch, 10) for ``x``."""
        return self.smax(self(x))


model = GarmentClassifier()

In [None]:
# CrossEntropyLoss applies log-softmax itself, so it expects unnormalized logits as input.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# NOTE(review): 1000 epochs with batch_size=4 is a very long run; the TensorBoard
# run directory and checkpoint files are named after the model class.
train(
    model,
    loss_fn,
    optimizer,
    training_loader,
    validation_loader,
    epochs=1000,
    model_name=type(model).__name__,
)

In [None]:
%tensorboard --logdir runs/GarmentClassifier_20230417_160545

In [None]:
# The model lives on the CPU; torchsummary's summary() defaults to device="cuda"
# and would build CUDA inputs whenever a GPU is available, mismatching the model.
summary(model, input_size=(1, 28, 28), device="cpu")

In [None]:
# Inspect a single training batch instead of printing every batch of the epoch.
sample_features, sample_labels = next(iter(training_loader))
sample_features.shape, sample_labels