In [3]:
import torch
import torch.nn as nn 
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt


## **Model Architecture**

In [2]:
# Simple 3-layer convolutional classifier.
class CNN_model(nn.Module):
    """Three conv blocks followed by a two-layer fully connected classifier head.

    Each conv block is Conv2d(3x3, stride 1, padding 1) -> ReLU -> MaxPool2d(2, 2).
    The convolution preserves H and W and each pool floor-halves them, so after
    the three blocks the feature map has 64 channels and spatial size
    ``(H // 8, W // 8)``. That size determines the input width of the first
    Linear layer.

    Args:
        in_channel: Number of channels in the input images (e.g. 3 for RGB).
        num_classes: Number of output logits. Defaults to 9, the value that was
            previously hard-coded.
        input_size: Spatial size of the input images, either an int (square
            images) or a ``(height, width)`` tuple. Defaults to 64, which
            reproduces the previously hard-coded flattened size
            64 * 8 * 8 = 4096. Images passed to ``forward`` must have this size,
            otherwise the first Linear layer raises a shape mismatch.

    Raises:
        ValueError: If ``input_size`` is smaller than 8 in either dimension,
            because three 2x2 pools would reduce it to zero.
    """

    def __init__(self, in_channel, num_classes=9, input_size=64):
        super().__init__()

        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size

        num_pools = 3          # MaxPool2d(2, 2) stages in feature_extractor below
        out_channels = 64      # channels produced by the last Conv2d below
        # Repeated floor-halving by the pools equals floor division by 2**num_pools.
        scale = 2 ** num_pools
        feat_h, feat_w = height // scale, width // scale
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small: {num_pools} 2x2 max-pools "
                f"need at least {scale}x{scale} pixels"
            )
        flat_features = out_channels * feat_h * feat_w

        # Convolution blocks
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channel, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Classifier head: input width derived from input_size instead of hard-coded.
        self.fully_connected = nn.Sequential(
            nn.Linear(flat_features, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        ``x`` should be a batched image tensor (batch, in_channel, H, W) whose
        H and W match the ``input_size`` given at construction.
        """
        features = self.feature_extractor(x)
        features = torch.flatten(features, 1)  # keep the batch dim, flatten C*H*W
        return self.fully_connected(features)
        
        

## **Training**

In [None]:
def Train(device, model, epochs):
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    
    EPOCHS = epochs
    train_loss, val_loss = [], []
    train_acc, val_acc = [],[]
    
    for epoch in range(EPOCHS):
        model.train()
        