# LeNet model

In [1]:
import pandas as pd
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from escnn import gspaces
from escnn import nn as enn

In [2]:
def build_mask(s, margin=2, dtype=torch.float32):
    """Build a soft-edged circular mask for square ``s`` x ``s`` images.

    Pixels whose squared distance ``r`` from the image centre does not
    exceed ``t`` get weight 1, where ``t`` is the square of the centre
    coordinate shrunk by ``margin`` percent. Pixels further out are
    weighted ``exp((t - r) / sig**2)`` with ``sig = 2``.

    Args:
        s: side length of the mask in pixels.
        margin: percentage by which the unit-weight radius is reduced.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(1, 1, s, s)``.
    """
    out = torch.zeros(1, 1, s, s, dtype=dtype)
    centre = (s - 1) / 2
    radius_sq = (centre - margin / 100. * centre) ** 2
    falloff = 2. ** 2

    # The squared offsets are identical along both axes, so compute them once.
    offsets_sq = [(i - centre) ** 2 for i in range(s)]

    for row, dx_sq in enumerate(offsets_sq):
        for col, dy_sq in enumerate(offsets_sq):
            dist_sq = dx_sq + dy_sq
            out[..., row, col] = (
                math.exp((radius_sq - dist_sq) / falloff)
                if dist_sq > radius_sq
                else 1.
            )
    return out

In [None]:
# Dataset / model configuration used when instantiating the networks below.
imgsize = 160       # pixels along one side
num_channels = 3    # passed to the model as its number of input channels
num_classes = 3     # passed to the model as its number of outputs

LeNet structure

In [3]:
class VanillaLeNet(nn.Module):
    """LeNet-style CNN whose input is multiplied by a soft circular mask.

    Layout: mask -> [conv, ReLU, 2x2 max-pool] x 2 -> fc(120) -> fc(84)
    -> dropout(0.5) -> fc(out_chan). ``forward`` returns the raw output of
    the last linear layer (no softmax is applied).

    Args:
        in_chan: number of input image channels.
        out_chan: number of outputs of the final linear layer.
        imsize: side length of the (square) input images in pixels.
        kernel_size: kernel size of both convolutions.
        N: unused; accepted so the signature matches ``CNSteerableLeNet``.

    Raises:
        ValueError: if ``imsize`` is too small for two conv + pool stages.
    """

    def __init__(self, in_chan, out_chan, imsize, kernel_size=5, N=None):
        super().__init__()

        # Derive the flattened feature size from the actual kernel size
        # instead of assuming kernel_size == 5.
        z = self._feature_map_size(imsize, kernel_size)
        if z < 1:
            raise ValueError(
                f"imsize={imsize} is too small for kernel_size={kernel_size}"
            )

        # Non-persistent buffer: it follows .to()/.cuda() with the module
        # (no per-forward device copy) but is kept out of state_dict so
        # existing checkpoints still load unchanged.
        self.register_buffer(
            "mask", build_mask(imsize, margin=1), persistent=False
        )

        self.conv1 = nn.Conv2d(in_chan, 6, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(6, 16, kernel_size, padding=1)
        self.fc1   = nn.Linear(16*z*z, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, out_chan)
        self.drop  = nn.Dropout(p=0.5)

        # Empty parameter that lets callers query the model's device via
        # ``self.dummy.device``; kept so the state_dict keys do not change.
        self.dummy = nn.Parameter(torch.empty(0))

    @staticmethod
    def _feature_map_size(imsize, kernel_size=5, padding=1):
        """Return the feature-map side after two conv + 2x2 pool stages."""
        size = imsize
        for _ in range(2):
            size = (size + 2 * padding - kernel_size + 1) // 2
        return size

    def loss(self, p, y):
        """Negative log-likelihood of class probabilities ``p``.

        Args:
            p: class probabilities, i.e. softmax applied to the output of
                ``forward`` (``forward`` itself does not apply softmax).
            y: integer class targets.

        Returns:
            Scalar loss tensor (mean over the batch).
        """
        # Clamp to the smallest positive normal value so that a probability
        # of exactly 0 yields a large finite loss instead of inf / NaN grads.
        if p.is_floating_point():
            p = p.clamp_min(torch.finfo(p.dtype).tiny)
        return F.nll_loss(torch.log(p), y)

    def enable_dropout(self):
        """Switch every ``nn.Dropout`` submodule to training mode.

        Dropout then stays active even if the rest of the model was put in
        eval mode (useful e.g. for Monte-Carlo dropout sampling).
        """
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                m.train()

        return

    def forward(self, x):
        """Apply the mask and the network; returns the ``fc3`` output."""
        x = x * self.mask

        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)

        # Flatten everything except the batch dimension.
        x = torch.flatten(x, 1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.drop(x)
        x = self.fc3(x)

        return x

In [4]:
class CNSteerableLeNet(nn.Module):
    """LeNet-style network with C_N rotation-steerable convolutions (escnn).

    Layout: mask -> [R2Conv, ReLU, antialiased 2x2 max-pool] x 2
    -> group pooling -> fc(120) -> fc(84) -> dropout(0.5) -> fc(out_chan).
    ``forward`` returns the raw output of the last linear layer (no softmax).

    Args:
        in_chan: number of input image channels (one trivial field each).
        out_chan: number of outputs of the final linear layer.
        imsize: side length of the (square) input images in pixels.
        kernel_size: kernel size of both steerable convolutions.
        N: order of the cyclic rotation group.

    Raises:
        ValueError: if ``imsize`` is too small for two conv + pool stages.
    """

    def __init__(self, in_chan, out_chan, imsize, kernel_size=5, N=8):
        super().__init__()

        # Derive the flattened feature size from the actual kernel size
        # instead of assuming kernel_size == 5.
        z = self._feature_map_size(imsize, kernel_size)
        if z < 1:
            raise ValueError(
                f"imsize={imsize} is too small for kernel_size={kernel_size}"
            )

        # escnn exposes the factory ``rot2dOnR2``; ``Rot2dOnR2`` is the
        # older e2cnn spelling. Accept whichever the installed package has.
        make_gspace = getattr(gspaces, "rot2dOnR2", None) or gspaces.Rot2dOnR2
        self.r2_act = make_gspace(N)

        # One trivial (scalar) field per input channel, so multi-channel
        # images are accepted rather than only single-channel ones.
        in_type = enn.FieldType(self.r2_act, in_chan*[self.r2_act.trivial_repr])
        self.input_type = in_type

        out_type = enn.FieldType(self.r2_act, 6*[self.r2_act.regular_repr])
        self.mask = enn.MaskModule(in_type, imsize, margin=1)
        self.conv1 = enn.R2Conv(in_type, out_type, kernel_size=kernel_size, padding=1, bias=False)
        self.relu1 = enn.ReLU(out_type, inplace=True)
        self.pool1 = enn.PointwiseMaxPoolAntialiased(out_type, kernel_size=2)

        in_type = self.pool1.out_type
        out_type = enn.FieldType(self.r2_act, 16*[self.r2_act.regular_repr])
        self.conv2 = enn.R2Conv(in_type, out_type, kernel_size=kernel_size, padding=1, bias=False)
        self.relu2 = enn.ReLU(out_type, inplace=True)
        self.pool2 = enn.PointwiseMaxPoolAntialiased(out_type, kernel_size=2)

        # Pools over the group dimension of each of the 16 regular fields.
        self.gpool = enn.GroupPooling(out_type)

        self.fc1   = nn.Linear(16*z*z, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, out_chan)

        self.drop  = nn.Dropout(p=0.5)

        # Empty parameter that lets callers query the model's device via
        # ``self.dummy.device``; kept so the state_dict keys do not change.
        self.dummy = nn.Parameter(torch.empty(0))

    @staticmethod
    def _feature_map_size(imsize, kernel_size=5, padding=1):
        """Return the feature-map side after two conv + 2x2 pool stages."""
        size = imsize
        for _ in range(2):
            size = (size + 2 * padding - kernel_size + 1) // 2
        return size

    def loss(self, p, y):
        """Negative log-likelihood of class probabilities ``p``.

        Args:
            p: class probabilities, i.e. softmax applied to the output of
                ``forward`` (``forward`` itself does not apply softmax).
            y: integer class targets.

        Returns:
            Scalar loss tensor (mean over the batch).
        """
        # Clamp to the smallest positive normal value so that a probability
        # of exactly 0 yields a large finite loss instead of inf / NaN grads.
        if p.is_floating_point():
            p = p.clamp_min(torch.finfo(p.dtype).tiny)
        return F.nll_loss(torch.log(p), y)

    def enable_dropout(self):
        """Switch every ``nn.Dropout`` submodule to training mode.

        Dropout then stays active even if the rest of the model was put in
        eval mode (useful e.g. for Monte-Carlo dropout sampling).
        """
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                m.train()

        return

    def forward(self, x):
        """Apply the mask and the network; returns the ``fc3`` output."""
        x = enn.GeometricTensor(x, self.input_type)

        # The mask module was constructed in __init__ but never applied;
        # apply it here so the input is masked before the first convolution.
        x = self.mask(x)

        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        x = self.gpool(x)
        # Unwrap the GeometricTensor to a plain torch.Tensor for the MLP head.
        x = x.tensor

        # Flatten everything except the batch dimension.
        x = torch.flatten(x, 1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.drop(x)
        x = self.fc3(x)

        return x

In [7]:
# Instantiate the baseline network; the image side passed is imgsize + 1.
hii = VanillaLeNet(
    in_chan=num_channels,
    out_chan=num_classes,
    imsize=imgsize + 1,
)

Running

In [None]:
#function to train the model on data
def train(model, trainloader, optimiser, device):
    """Train ``model`` for one epoch and return the mean per-sample loss.

    Args:
        model: network whose output is passed to ``nn.CrossEntropyLoss``.
        trainloader: iterable yielding ``(data, labels)`` batches.
        optimiser: optimiser over ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        Sum of per-batch losses weighted by batch size, divided by the
        number of samples actually seen (a Python float).

    Raises:
        ZeroDivisionError: if ``trainloader`` yields no samples.
    """
    # The criterion is stateless, so build it once rather than per batch.
    loss_criterion = nn.CrossEntropyLoss()
    train_loss = 0.0
    n_seen = 0
    model.train()

    for data, labels in trainloader:
        data = data.to(device)
        # Cast to int64 directly on the target device; going through
        # torch.LongTensor would bounce the labels via the CPU.
        labels = labels.to(device=device, dtype=torch.long)

        optimiser.zero_grad()

        #calculate train loss
        p_y = model(data)
        loss = loss_criterion(p_y, labels)

        #feed the loss back
        loss.backward()
        optimiser.step()

        batch_size = data.size(0)
        train_loss += loss.item() * batch_size
        n_seen += batch_size

    # Normalise by the samples processed, which stays correct when the
    # loader drops the last batch or uses a sampler over part of the dataset.
    return train_loss / n_seen

#function to test the model on data
def validate(model, testloader, device):
    """Evaluate ``model`` on ``testloader`` and compute classification metrics.

    Args:
        model: network whose output is passed to ``nn.CrossEntropyLoss``
            and arg-maxed over dim 1 to get predictions.
        testloader: iterable yielding ``(data, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        Tuple ``(test_loss, accuracy, roc_auc, prediction, target, recall,
        f1, precision)`` where ``prediction`` and ``target`` are lists of
        class indices and the losses/accuracy are averaged over the samples
        actually seen.

    Raises:
        ZeroDivisionError: if ``testloader`` yields no samples.
    """
    # The criterion is stateless, so build it once rather than per batch.
    loss_criterion = nn.CrossEntropyLoss()
    prediction = []
    target = []
    prob_batches = []
    correct = 0
    test_loss = 0.0
    n_seen = 0

    model.eval()
    with torch.no_grad():
        for data, labels in testloader:
            data = data.to(device)
            # Cast to int64 directly on the target device; going through
            # torch.LongTensor would bounce the labels via the CPU.
            labels = labels.to(device=device, dtype=torch.long)

            p_y = model(data)

            #calculate test loss
            loss = loss_criterion(p_y, labels)
            batch_size = data.size(0)
            test_loss += loss.item() * batch_size
            n_seen += batch_size

            #values for metrics
            preds = p_y.argmax(dim=1)
            correct += (preds == labels).sum().item()
            prediction += preds.tolist()
            target += labels.tolist()
            # Class probabilities (float64 so each row sums to 1 tightly),
            # needed for the multiclass ROC AUC below.
            prob_batches.append(F.softmax(p_y.double(), dim=1).cpu())

    # Normalise by the samples processed, which stays correct when the
    # loader drops the last batch or uses a sampler over part of the dataset.
    test_loss /= n_seen
    accuracy = correct / n_seen

    # NOTE(review): roc_auc_score, recall_score, f1_score and precision_score
    # are not imported in the visible part of this file; the calls follow the
    # sklearn.metrics signatures - confirm they are imported elsewhere.
    probs = torch.cat(prob_batches, dim=0)
    if probs.shape[1] > 2:
        # The binary defaults raise on multiclass targets: use one-vs-rest
        # AUC on the probabilities and support-weighted averages.
        roc_auc = roc_auc_score(target, probs.numpy(), multi_class="ovr")
        recall = recall_score(target, prediction, average='weighted')
        f1 = f1_score(target, prediction, average='weighted')
        precision = precision_score(target, prediction, average='weighted')
    else:
        roc_auc = roc_auc_score(target, prediction)
        recall = recall_score(target, prediction)
        f1 = f1_score(target, prediction)
        precision = precision_score(target, prediction)
    return test_loss, accuracy, roc_auc, prediction, target, recall, f1, precision

In [None]:
#function to run the entire model train/test loop
def run_func():
    """Placeholder for the full train / validate / test loop.

    The whole implementation is commented out, which left the function
    without a body (a syntax error). The draft is kept below for reference;
    it relies on names that are not defined in the visible code (e.g. ``X``,
    ``y``, ``CustomNN``, ``batch_size``, ``num_epochs``, ``modfile``), so it
    cannot simply be re-enabled as is.

    Raises:
        NotImplementedError: always, until the draft below is restored.
    """
    # ---- draft implementation (disabled) ----
    #split the data into train, test and validation
    # X_train, X_val_and_test, Y_train, Y_val_and_test = train_test_split(X, y,test_size=0.3)
    # X_val, X_test, Y_val, Y_test = train_test_split(X_val_and_test, Y_val_and_test, test_size=0.5)

    # X_train_tensor = torch.from_numpy(X_train).float()
    # Y_train_tensor = torch.from_numpy(Y_train.astype("float64")).float()

    # X_val_tensor = torch.from_numpy(X_val).float()
    # Y_val_tensor = torch.from_numpy(Y_val.astype("float64")).float()

    # X_test_tensor = torch.from_numpy(X_test).float()
    # Y_test_tensor = torch.from_numpy(Y_test.astype("float64")).float()

    # trainset = TensorDataset(X_train_tensor, Y_train_tensor)
    # valset = TensorDataset(X_val_tensor, Y_val_tensor)
    # testset = TensorDataset(X_test_tensor, Y_test_tensor)

    # train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    # val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=True)
    # test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)

    # #create the model
    # model = CustomNN(input_size, hidden_size, output_size).to(device)

    # #optimizer and learning rate scheduler
    # optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay = weight_decay)#0.005, #0.01
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, patience=2, factor=0.9)

    # #array to store metrics
    # result_arr = np.zeros((num_epochs,8))

    # _bestloss = 1.
    # bestepoch = 0
    # #NN learning
    # for epoch in range(num_epochs):

    #     train_loss = train(model, train_loader, optimizer, device)
    #     val_loss, accuracy, roc_auc, prediction, target,  recall, f1, precision = validate(model, val_loader, device)

    #     scheduler.step(val_loss)

    #     if early_stopping and val_loss<_bestloss:
    #         _bestloss = val_loss
    #         torch.save(model.state_dict(), modfile)
    #         best_epoch = epoch
    #         cm = confusion_matrix(target, prediction,normalize = 'true')

    #     #set output row
    #     results = [epoch, train_loss, val_loss, accuracy, roc_auc,  recall, f1, precision]
    #     result_arr[epoch] = results

    #     #print epoch results
    #     if not QUIET:
    #         print('Epoch: {}, Validation Loss: {:4f}, Validation Accuracy: {:4f}'.format(epoch, val_loss, accuracy))
    #         print('Current learning rate is: {}'.format(optimizer.param_groups[0]['lr']))

    # test_arr=[]
    # if TEST_USE:
    #     #Test data run
    #     bestmodel=CustomNN(input_size, hidden_size, output_size).to(device)
    #     bestmodel.load_state_dict(torch.load('modfile.pt'))
    #     test_loss, test_accuracy, test_auc, prediction2,target2,  recall2, f12, precision2 = validate(bestmodel, test_loader, device)
    #     test_arr=[test_loss, test_accuracy, test_auc, prediction2,target2,  recall2, f12, precision2]

    # if not early_stopping:
    #     torch.save(model.state_dict(), modfile)
    #     best_epoch = -1
    #     cm = confusion_matrix(target, prediction,normalize = 'true')

    # return result_arr, cm, test_arr, best_epoch

    # TODO: when restoring the draft, note that it initialises ``bestepoch``
    # but later assigns/returns ``best_epoch``, and saves to ``modfile`` but
    # reloads from the literal 'modfile.pt'.
    raise NotImplementedError(
        "run_func is a placeholder: its implementation is commented out"
    )