In [2]:
import numpy as np 
import matplotlib.pyplot as plt 
import matplotlib.ticker as ticker
from mpl_toolkits.mplot3d import Axes3D

import torch 
import torch.nn as nn 
from sklearn.model_selection import train_test_split

import torchvision
import torchvision.transforms as transfroms

In [3]:
# Pick the best available accelerator so the notebook also runs on machines
# without Apple-silicon MPS support: MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('device: ', device)

device:  mps


In [4]:
# Seed torch's global RNG so weight initialisation, data shuffling and any
# unseeded random splits are reproducible across "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

learning_rate = 0.001  # step size passed to the Adam optimizer
batch_size = 100       # training batch size (evaluation loaders use half of it)
epochs = 5             # number of training epochs

In [5]:
DATA_ROOT = './data/MNIST'

# A single transform shared by all splits: ToTensor converts PIL images to
# float tensors scaled to [0, 1].
to_tensor = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

train_set = torchvision.datasets.MNIST(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=to_tensor,
)
# The official held-out split; it is divided below into validation and test halves.
dt_set = torchvision.datasets.MNIST(
    root=DATA_ROOT,
    train=False,
    download=True,
    transform=to_tensor,
)

# Explicit lengths always sum to len(dt_set), even when it is odd, and a
# dedicated seeded generator keeps the split identical across runs.
val_size = len(dt_set) // 2
test_set, val_set = torch.utils.data.random_split(
    dt_set,
    [len(dt_set) - val_size, val_size],
    generator=torch.Generator().manual_seed(0),
)

# Evaluation uses half the training batch size (at least 1).
eval_batch_size = max(1, batch_size // 2)

# Shuffle training batches every epoch; evaluation order does not matter.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(val_set, batch_size=eval_batch_size, shuffle=False, drop_last=False)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=eval_batch_size, shuffle=False)

In [6]:
class GenMNIST(nn.Module):
    """Fully connected classifier producing 10 logits from 784 input features.

    Architecture: Linear(784, 512) -> ReLU -> Linear(512, 256) -> ReLU ->
    Linear(256, 10). The output is raw logits (no softmax), suitable for
    ``nn.CrossEntropyLoss``.

    Args:
        inp_: Unused; kept so existing ``GenMNIST(...)`` calls keep working.
        output_: Unused; kept so existing ``GenMNIST(...)`` calls keep working.
    """

    def __init__(self, inp_=2, output_=1):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, 10)``.

        Everything after the batch dimension is flattened first, so both image
        batches such as ``(batch, 1, 28, 28)`` and already-flat ``(batch, 784)``
        inputs are accepted; the first Linear layer needs 784 features.
        """
        # Flattening here (instead of adding nn.Flatten to the Sequential)
        # keeps the state_dict keys unchanged for saved checkpoints.
        x = torch.flatten(x, start_dim=1)
        return self.layer(x)

In [7]:
from statistics import mean    
import time

In [8]:
def _mean_or_nan(values):
    """Return the mean of ``values``, or ``nan`` when ``values`` is empty.

    ``statistics.mean`` raises on an empty sequence, which would abort a whole
    training run if a loader yields no batches.
    """
    return mean(values) if values else float('nan')


def _evaluate(model, criterion_, loader, device_):
    """Return the mean per-batch loss of ``model`` over ``loader``, without gradients."""
    batch_losses = []
    with torch.no_grad():
        for samples, labels in loader:
            samples = samples.to(device_)
            labels = labels.to(device_)
            outputs = model(samples)
            batch_losses.append(criterion_(outputs, labels).item())
    return _mean_or_nan(batch_losses)


def train(model, criterion_, optimizer_, scheduler_=None, num_epochs=40, first_epoch=1,
          log_every=1, train_loader_=None, valid_loader_=None, test_loader_=None,
          device_=None):
    """Train ``model`` and record per-epoch train/validation/test losses.

    Each epoch runs one optimisation pass over the training loader (train
    mode), then evaluates the validation and test loaders in eval mode
    without gradient tracking.

    Args:
        model: Module to train. It is not moved by this function; place it on
            ``device_`` beforehand.
        criterion_: Loss, called as ``criterion_(outputs, labels)``.
        optimizer_: Optimizer over ``model``'s parameters.
        scheduler_: Optional LR scheduler stepped once per epoch. ``None``
            leaves the optimizer's learning rate unchanged.
        num_epochs: Number of epochs to run.
        first_epoch: Number given to the first epoch (affects logging only).
        log_every: Print a progress line every ``log_every`` epochs and after
            the final epoch; ``0`` or less disables per-epoch progress lines.
        train_loader_, valid_loader_, test_loader_: Iterables of
            ``(samples, labels)`` batches. ``None`` falls back to the notebook
            globals ``train_loader`` / ``valid_loader`` / ``test_loader``.
        device_: Device each batch is moved to. ``None`` uses the global ``device``.

    Returns:
        ``(train_losses, valid_losses, test_losses)``: three lists with one
        entry per epoch, each the mean batch loss of that epoch (``nan`` when
        the corresponding loader yielded no batches).
    """
    # Fall back to notebook-level globals so existing calls keep working;
    # the globals are only looked up when no explicit argument is given.
    train_loader_ = train_loader if train_loader_ is None else train_loader_
    valid_loader_ = valid_loader if valid_loader_ is None else valid_loader_
    test_loader_ = test_loader if test_loader_ is None else test_loader_
    device_ = device if device_ is None else device_

    train_losses = []
    valid_losses = []
    test_losses = []
    last_epoch = first_epoch + num_epochs - 1
    print("----------------------------------------------------------------------------")

    start_time = time.time()
    for epoch in range(first_epoch, last_epoch + 1):
        model.train()

        batch_losses = []
        for samples, labels in train_loader_:
            samples = samples.to(device_)
            labels = labels.to(device_)

            optimizer_.zero_grad()
            outputs = model(samples)

            # Call the module rather than .forward() so registered hooks run.
            loss = criterion_(outputs, labels)
            batch_losses.append(loss.item())

            loss.backward()
            optimizer_.step()

        train_losses.append(_mean_or_nan(batch_losses))

        # Per-epoch means, consistent with train_losses.
        model.eval()
        valid_losses.append(_evaluate(model, criterion_, valid_loader_, device_))
        test_losses.append(_evaluate(model, criterion_, test_loader_, device_))

        if log_every > 0 and (epoch % log_every == 0 or epoch == last_epoch):
            curr_time = round(time.time() - start_time)
            train_rec = round(train_losses[-1], 5)
            valid_rec = round(valid_losses[-1], 5)
            test_rec = round(test_losses[-1], 5)
            if scheduler_ is not None:
                curr_lr = scheduler_.get_last_lr()
            else:
                curr_lr = [group['lr'] for group in optimizer_.param_groups]
            print('Epoch', epoch, ' / ', last_epoch)
            print(f"\t [Train loss: {train_rec}]  [Validation loss: {valid_rec}] [Test loss: {test_rec}]  [curr LR = {curr_lr} ],[elapsed_time = {curr_time}sec]")

        if scheduler_ is not None:
            scheduler_.step()

    print(f"\nTrain Ended, total_elapsed_time = {round(time.time()-start_time)} ")
    print("--------------------------------------------------------------------")

    return train_losses, valid_losses, test_losses


In [9]:
model = GenMNIST().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# `train` calls `scheduler_.step()` once per epoch; a LambdaLR whose factor is
# always 1.0 keeps the learning rate constant at `learning_rate`.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda _epoch: 1.0)

# `num_epochs` counts passes over the training set (the configured `epochs`),
# not the number of batches per epoch.
train_losses, valid_losses, test_losses = train(
    model,
    criterion_=criterion,
    optimizer_=optimizer,
    scheduler_=scheduler,
    num_epochs=epochs,
)

: 

: 