In [75]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms

In [76]:
def add_to_class(Class):
    def wrapper(obj):
        setattr(Class,obj.__name__,obj)
    return wrapper

### Load Data

In [77]:
class FashionMNIST():
    def __init__(self, batch_size = 64, resize = (28,28), num_workers = 8):
        # super.__init__()
        trans = transforms.Compose([transforms.Resize(resize),transforms.ToTensor()])
        self.train = torchvision.datasets.FashionMNIST(root="./data", train = True, download=True, transform = trans)
        self.val = torchvision.datasets.FashionMNIST(root="./data", train = False, download =True, transform = trans)
        self.batch_size = batch_size
        self.num_workers = num_workers

In [78]:
@add_to_class(FashionMNIST)
def get_dataloader(self,train):
    X = self.train if train == True else self.val
    return torch.utils.data.DataLoader(X, batch_size = self.batch_size, num_workers = self.num_workers, shuffle=train)

In [79]:
data = FashionMNIST()

In [80]:
train_dataloader = data.get_dataloader(train=True)
X,y = next(iter(train_dataloader))
X.shape, X.dtype, y.shape, y.dtype

(torch.Size([64, 1, 28, 28]), torch.float32, torch.Size([64]), torch.int64)

### Model

In [110]:
class Model(nn.Module):
    def __init__(self,num_outputs=10,lr=0.01):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28,10),
        )
        self.lr = lr

In [111]:
@add_to_class(Model)
def forward(self,X):
    logits = self.net(X)
    return F.softmax(logits, dim=1)

In [134]:
loss_fn = nn.CrossEntropyLoss()

### Training Step

In [144]:
model = Model()
data = FashionMNIST()
optimizer = torch.optim.SGD(params = model.parameters(), lr = 0.1)

max_epochs = 10
train_dataloader = data.get_dataloader(train=True)
val_dataloader = data.get_dataloader(train=False)

train_loss = 0
test_loss = 0 

for epoch in range(0,max_epochs):

    for X,y in train_dataloader:

        preds = model(X)
        loss = loss_fn(preds,y).mean()
        train_loss += loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_loss /= len(train_dataloader)

    for X,y in val_dataloader:

        with torch.inference_mode():
            y_pred = model(X)
        test_loss += loss_fn(y_pred, y)

    test_loss /= len(val_dataloader)

    print(f"epoch-{epoch+1} : training loss -{ train_loss:.4f}")
    print(f"epoch-{epoch+1} : test loss - {test_loss:.4f}")
    

epoch-1 : training loss -1.8637
epoch-1 : test loss - 1.7640
epoch-2 : training loss -1.7347
epoch-2 : test loss - 1.7345
epoch-3 : training loss -1.7103
epoch-3 : test loss - 1.7181
epoch-4 : training loss -1.6985
epoch-4 : test loss - 1.7092
epoch-5 : training loss -1.6910
epoch-5 : test loss - 1.7035
epoch-6 : training loss -1.6858
epoch-6 : test loss - 1.6993
epoch-7 : training loss -1.6817
epoch-7 : test loss - 1.6955
epoch-8 : training loss -1.6784
epoch-8 : test loss - 1.6927
epoch-9 : training loss -1.6756
epoch-9 : test loss - 1.6905
epoch-10 : training loss -1.6733
epoch-10 : test loss - 1.6888


In [143]:
test_batch = data.get_dataloader(train=False)
X,y = next(iter(test_batch))

preds = model(X[15])
output = preds.argmax(dim=-1)
print(output, y[15])

tensor([1]) tensor(1)
