In [41]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from typing import Tuple
import boto3 as boto
import random
import _pickle as cPickle
from torch.utils.data.distributed import DistributedSampler
import os
import torchvision.transforms as T 
from torchvision.transforms import v2
from torchvision.io import read_image

In [42]:
a = read_image("patient2.png")

In [43]:
a.size()

torch.Size([4, 620, 488])

In [44]:
40089312/302560

132.50037017451083

In [45]:
a.sum(axis=(1,2))

tensor([40089312, 34832137, 28956982, 76746330])

In [46]:
(a.size()[1] * a.size()[2])

302560

In [47]:
a.sum(axis=(1,2))/(a.size()[1] * a.size()[2])

tensor([132.5004, 115.1247,  95.7066, 253.6566])

In [48]:
torch.mean(a.float(),dim = (1,2))

tensor([132.5004, 115.1247,  95.7066, 253.6565])

In [49]:
torch.std(a.float(),dim = (1,2))

tensor([77.3145, 65.6562, 64.2510, 18.4600])

In [50]:
#SINGLE GPU INSTANCE CAPABILITIES ONLY. WILL HAVE TO UPDATE FOR MULTIGPU
class PRPDataSet(Dataset):
    
    def __init__(self, patient_ids: list,patient_labels: dict, num_classification_output: int,path: str, size: tuple):
        
        #dictionary of where patient ids are keys and rating is value
        self.patient_labels = patient_labels
        
        #list of patient ids
        self.patient_ids = patient_ids
        
        #path
        self.path = path
        
        self.num_classification_output = num_classification_output
        
        self.resize = v2.Resize(size)
        self.to_image = v2.ToImage()
        self.to_Dtype = v2.ToDtype(torch.float32, scale=True)
        
        
        
    def __len__(self):
        return len(self.patient_ids)
    
    def __getitem__(self,idx:int):
        
        #indexing via list
        patient = self.patient_ids[idx]
        
        #pulling image from storage bucket and un-pickling it
        #THIS MAY HAVE TO BE CHANGED DEPENDING ON HOW EXACTLY THE IMAGES ARE STORED
        #image = cPickle.load(open(f'{self.path}/{patient}/image'),"rb")
        image = read_image(f'{patient}.png')
        
        #calculating mean and standard deviation along channels for one photo
        image = self.resize(image)
        image = self.to_image(image)
        image = self.to_Dtype(image)
        mean,std = self.mean_std(image)
        norm = v2.Normalize(mean=mean,std = std)
        image = norm(image)
        
        #transformation used on each image 
        #Note that mean/std are vectors of len(4) for the alpha,RBG channels
        #self.transforms = v2.Compose([
            #v2.Resize(size),
            #v2.ToImage(), 
            #v2.ToDtype(torch.float32, scale=True),
            #v2.Normalize(mean = mean,std = std)
        #])
        
        #Resize image
        #resized_image = self.transforms(image)
        #print(f"IMAGE SUM: {torch.sum(resized_image)}")
        
        #Get label
        label_tensor = [0 for _ in range(self.num_classification_output)]
        label = torch.tensor(self.patient_labels[patient])
        label_tensor[label-1] = 1
        label_tensor = torch.tensor(label_tensor)
        
        return image, label_tensor
    
    #Should maybe try to calculate mean_std deviation once prior to running script and store data
    #So that we do not have to rerun this step everytime the script is rerun 
    #
    #Note we are calculating the mean of pixel average for each the image
    #And the mean of the pixel std. dev for each the images
    def mean_std(self,image):
        
        a = torch.mean(image,axis=(1,2))
        b = torch.mean(image,axis=(1,2))
        
        return a,b
    
def sequential_train_test_split(split: tuple, labels:dict):
    
    #Getting number of patients 
    num_patients = len(labels.keys())
    #Getting how patients to include in training
    training_num = int(split[0] * num_patients)
    #creating training set
    training_patients = labels.keys()[:training_num]
    #Creating Test set 
    testing_patients = labels.keys()[training_num:]

    return training_patients, testing_patients

def random_train_test_split(split: tuple, labels:dict):
    
    #Converting label list to label set to use subtraction
    patient_set = set(labels.keys())
    
    #Traning set
    training_patients = set()

    #calculating number of patients
    num_patients = len(labels.keys())
    
    #calculating number of patients in training set
    training_num = int(split[0] * num_patients)
    
    
    i = 0
    #looping number of patients 
    while i != training_num:
        #Choosing a patient from the list of left over patients not yet chosen
        curr_patient = random.choice(list(patient_set-training_patients))
        #adding patient to list of training patient
        training_patients.add(curr_patient)
        #Iterating
        i += 1

    #Getting test set
    test_patients = patient_set - training_patients
    #print(list(training_patients))
    #print(list(test_patients))

    #Reconverting back to list 
    return list(training_patients), list(test_patients)

def get_train_test_dataset(dict_path: str, data_path: str, num_class_output: int,size: tuple ,sequential: bool, split:tuple):
    
    #Getting dictionary that has patient keys and scores as values
    #Path may need to be changed
    patient_labels = cPickle.load(open(dict_path, 'rb'))
    
    #If want the data to be split sequentially (i.e in order)
    if sequential:
        train, test = sequential_train_test_split(split,patient_labels)
    else:
    #If want data split randomly (More Likely used)
        train, test = random_train_test_split(split,patient_labels)
    
    
    #Creating PRPDataSet objects for both train and test sets
    train_set = PRPDataSet(train,patient_labels,num_class_output,data_path,size)
    test_set = PRPDataSet(test,patient_labels,num_class_output,data_path,size)
    
    return train_set,test_set

def PRPDataLoader(dict_path: str, data_path: str, num_class_output: int,size: tuple,sequential: bool, split: tuple, batch: int):
    
    #Getting PRPDataSet objects for both rain and test sets 
    train_set, test_set = get_train_test_dataset(dict_path, data_path, num_class_output, size ,sequential, split)
    #print(f"LEN TRAIN: {len(train_set)}")
    #Creating DataLoader with Train and Set Data sets
    train_generator = DataLoader(train_set, batch_size = batch, shuffle = True)
    test_generator = DataLoader(test_set, batch_size = batch, shuffle = True)
    
    #print(f"TRAIN GEN LEN: {len(train_generator)}")
    #print(f"TEST GEN LEN: {len(test_generator)}")
    
    return train_generator,test_generator

In [56]:
#Not set up for multi-gpu running, need to add things to get that ready 
#Also would like to add confusion matrix output
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
import pandas as pd
import _pickle as cPickle
import random
import itertools
import time
from datetime import datetime, timedelta
from matplotlib import pyplot as plt
import boto3 as boto
import torch.multiprocessing as mp



#class Trainer:
    
    #def __init___(self, 
                  #model : torch.nn.Module,
                  #optimizer: torch.optim.Optimizer,
                  #loss_fn: torch.nn,
                  #save_interval: int, 
                  #metric_interval: int,
                  #train_data: DataLoader,
                  #validation_data: DataLoader = None, 
                  #test_data: DataLoader = None,
                  #save_path: str = None): 
        
        #Setting all variables equal to a local counterpart 
        #self.model = model
        #self.optimizer = optimizer
        #self.loss_fn = loss_fn
        #self.save_interval = save_interval
        #self.metric_interval = metric_interval
        #self.train_data = train_data
        #self.validation_data = validation_data
        #self.test_data = test_data
        #self.save_path = save_path
        
        #going to be used in evaluating function to decrease latency of model 
        #self.curr_predictions = []
        #self.curr_labels = []
        
class Trainer:
    def __init__(self,
                 model: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 loss_fn: torch.nn.Module,
                 save_interval: int,
                 metric_interval: int,
                 train_data: DataLoader,
                 validation_data: DataLoader = None,
                 test_data: DataLoader = None,
                 save_path: str = None):
        
        # Setting all variables equal to a local counterpart
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.save_interval = save_interval
        self.metric_interval = metric_interval
        self.train_data = train_data
        self.validation_data = validation_data
        self.test_data = test_data
        self.save_path = save_path

        # Going to be used in evaluating function to decrease latency of model
        self.curr_predictions = []
        self.curr_labels = []
        
    def _run_batch(self,batch: torch.Tensor, batch_labels: torch.Tensor):
        #Setting current gradient to zero for new batch
        self.optimizer.zero_grad()
        #Running model on current batch
        #print("running _run_batch")
        pred_output = self.model(batch)
        #print("Done running batch")
        #Appending predicted values to list for evaluation 
        self.curr_predictions.append(pred_output)
        #Appending label values to list for evaluation
        self.curr_labels.append(batch_labels)
        
        #Computing loss 
        #print(f"PRED OUTPUT: {pred_output}")
        #print(f"BATCH LABEL: {batch_labels}")
        loss = self.loss_fn(pred_output.float(),batch_labels.float())
        
        #Computing gradient for each parameter in model
        loss.backward()
        #grads = [p.grad for p in self.model.parameters()]
        #print(grads[:50])
        #Gradient descent 
        self.optimizer.step()
        
    def _run_epoch(self, epoch:int):
        #Including epoch num in init bc likely will want to print out at somepoint to see how quickly model runs
        #Setting model to train
        
        self.model.train()
        #Re-initiating prediction/label accumulator for evaluation of specific epoch
        
        self.curr_predictions = []
        self.curr_labels = []
        
        #Looping over each training batch
        #print("training model")
        for batch_tensor, batch_labels in self.train_data:
            
            #print(f"batch_tensor: {batch_tensor.shape}")
            #print(f"batch_labels: {batch_labels.shape}")
            #print(batch_tensor)
            
            #Running gradient descent on each batch
            self._run_batch(batch_tensor,batch_labels)
        #print("done training model")
        #print("exiting _run_epoch")
    
    #TODO: FINISH how to save to server
    def _save_checkpoint(self, epoch: int):
        #Getting the model weights at particular checkpoint
        #print("in save_checkpoint method")
        checkpoint_model = self.model.state_dict()
        #print("after getting the state-dict")
        #Pickling model into checkpoint_{epoch} file
        #Note that pickle.dump saves model in local directory
        #Need to delete after dump and upload
        torch.save(checkpoint_model,f'checkpoint_{epoch}.pt')
        #cPickle.dump(checkpoint_model, open(f'checkpoint_{epoch}.pt', 'wb'))
        #print("after pickle dump")
        #NEED TO FINISH SAVING TO FOLDER OF DICTIONARIES
        
    
    def train(self, num_epochs:int):
        
        #Looping over number of epochs
        for epoch in range(1,num_epochs+1):
            
            #Running an epoch
            print(f"running {epoch} epoch")
            self._run_epoch(epoch)
            
            #print("outside self.save_interval")
            #Code to save the model every save_interval   
            if self.save_interval > 0 and epoch % self.save_interval == 0:
                #print("In save_interval")
                self._save_checkpoint(epoch)
            #Saving the last model
            elif epoch == num_epochs:
                self._save_checkpoint(epoch)

            #Evaluating model every metric_interval
            #print("outside self.metric_interval")
            if self.metric_interval > 0 and epoch % self.metric_interval == 0:
                #Decreases time bc saved inferences for training data in list
                #Evaluating Training set
                self._evaluate(None)
                #Will have to do inference for test set
                if self.test_data != None:
                    #Evaluating test set
                    self._evaluate(self.test_data)
                    #Resetting the model to train 
                    self.model.train()
                
    def _evaluate(self, dataloader: DataLoader = None):
        #Converting to torch.no_grad to prevent gradient calculation
        with torch.no_grad():
            #Set model to evaluation
            self.model.eval()
            #If dataloader none, it means we are looking at the current training set accuracy 
            if dataloader == None:
                #Using already predicted values that we accumulated during the training 
                #This obviously is a lower bound on the accuracy of our model on the training set
                #However, transformer latency is extremely large so this will decrease training time overall
                #Also we don't care about the actual training accuracy, we only care about the overall trends of 
                #training accuracy
                predict_output = torch.vstack(self.curr_predictions)
                labels = torch.vstack(self.curr_labels)
                print("/tTRAINING SET VALUES")
            else:
                #Creating accumulators for test set
                test_predict = []
                test_labels = []
                #Looping over each tensor and label in the dataloader
                for batch_tensor, batch_label in dataloader:
                    #Predicting using model on test set
                    #print("IN THE ELSE STATEMENT TO TEST TEST SET")
                    prediction = self.model(batch_tensor)
                    #accumulating model predictions and labels of test set
                    test_predict.append(prediction)
                    test_labels.append(batch_label)
                    #print(f"TEST PREDICT len:{len(test_predict)}")
                #Vstacking outputs and labels so that tensors read (patient x 1)
                #Note loss function is MSE, so output from model will be a singular value that relates to our 
                #actual scale
                #This differs from CrossEntropyLoss, where model output would be vector of length (num_classes)
                #and each entry would be a probability of particular class
                predict_output = torch.vstack(test_predict)
                labels = torch.vstack(test_labels)
                print("/tTEST SET VALUES")
            
            #Squeezing output to get rid of nested tensors
            predict_output = torch.squeeze(predict_output)
            labels = torch.squeeze(labels)
            
            predict_output = torch.argmax(predict_output,axis = 1)
            labels = torch.argmax(labels,axis = 1)
            #Calculating loss of the model for train/test set
            loss = self.loss_fn(predict_output.float(), labels.float())
            #Calculating Mean Absolute Error based on train/test set
            #print(f"PREDICT_OUTPUT: {predict_output}")
            #print(f"LABELS: {labels}")
            MAE = (predict_output.float() - labels.float()).abs().mean().item()
            
            #Rounding predicted output so that it matches the exact categories given by Norwood scale
            #predict_output = torch.round(predict_output)
            
            #Calculating how many predictions were correct
            num_correct = (predict_output == labels).sum().item()
            
            #Calculating accuracy of model
            acc = num_correct / len(labels)
            
            print(f"\t\t NUMBER CORRECT: {num_correct}")
            print(f"\t\t ACCURACY: {acc}")
            print(f"\t\t MEAN ABSOLUTE ERROR: {MAE}")
            print(f"\t\t LOSS: {loss}")
            print(f"\t\t PREDICTED: {predict_output}")
            print(f"\t\t LABELS: {labels}")
            print(f"++++++++++++++++++++++++++++++++++++++++++++++++++++")                     

In [52]:
import torch
from torch import nn
import numpy as np
import pandas as pd
import os
import pydicom as dicom
import math
import time

class PositionalEncoding(nn.Module):
        
    def __init__(self, data_shape, dropout = .1):
        super(PositionalEncoding,self).__init__()
        
        #Get data shape
        #self.in_channels, self.row_len, self.col_len = data_shape
        self.row_len, self.col_len = data_shape
        
        self.learned_embedding = torch.zeros(data_shape)
        self.learned_embedding = nn.Parameter(self.learned_embedding[None,:,:])
                                              
        
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self,data):
        
        data = data + self.learned_embedding
        data = self.dropout(data)
        
        return data

class Convlayer(nn.Module):
    
    def __init__(self,data_shape:tuple,num_patches:int,output_dim:int = None):
        super(Convlayer,self).__init__()
        self.num_patches = num_patches
        
        self.batch,self.in_channels, self.row_len, self.col_len = data_shape
        
        assert self.row_len % num_patches == 0 
        assert self.col_len % num_patches == 0
        
        
        self.patch_row = self.row_len // num_patches
        self.patch_col = self.col_len // num_patches
        
        self.embed_dim = int(self.in_channels * self.patch_row * self.patch_col)
        
        self.kernel_len = (int(self.patch_row), int(self.patch_col))
        
        patch_area = int(self.row_len * self.col_len)
        
        self.conv2d_1 = nn.Conv2d(in_channels = self.in_channels, 
                                  out_channels = self.embed_dim, 
                                  kernel_size =self. kernel_len, 
                                  stride = self.kernel_len)
        
        
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten(start_dim=2) 
        if output_dim == None:
            self.dnn = nn.Linear(self.embed_dim,self.embed_dim)
        else:
            self.output_dim = output_dim
            self.dnn = nn.Linear(self.embed_dim,output_dim)
    def forward(self,data):
        
        
        #x = self.conv2d_1(data)
        #print(x.shape)
        #x = self.relu(x)
        #x = self.flatten(x)
        #print(x.shape)
        #x = torch.transpose(x,1,2)
        #print(f'DATA SHAPE: {data.shape}')
        batch = data.shape[0]
        patches = data.unfold(2,self.row_len,self.row_len).unfold(3,self.col_len,self.col_len)
        patches = torch.reshape(patches,(batch,self.num_patches**2,self.embed_dim))
        
        #print(x.shape)
        x = self.dnn(patches)
        
        return x
    
class MultiHeadAttention(nn.Module):
    
    def __init__(self,data_shape,num_heads):
        
        super(MultiHeadAttention,self).__init__()
        #self.batch, self.patch, self.embed = data_shape
        self.patch, self.embed = data_shape
        self.attn = nn.MultiheadAttention(self.embed,num_heads,batch_first = True)
    def forward(self,data):
        #data = torch.reshape(data,(data.shape[1],data.shape[2],data.shape[3]))
        outputs , _ = self.attn(query=data, key=data, value=data, need_weights = False)
        return outputs

class MLP(nn.Module):
    def __init__(self,data_shape,output_size,dropout = .1):
        super(MLP,self).__init__()
        #self.batch, self.patch, self.embed = data_shape
        self.patch, self.embed = data_shape
        hidden_output = self.embed * 2
        self.lnn1 = nn.Linear(self.embed, hidden_output)
        self.dropout1 = nn.Dropout(dropout)
        self.fnn2 = nn.Linear(hidden_output, output_size)
        self.dropout2 = nn.Dropout(dropout)
        self.gelu = nn.GELU()
    def forward(self,data):
        
        x = self.lnn1(data)
        x = self.gelu(x)
        x = self.dropout1(x)
        x = self.fnn2(x)
        x = self.gelu(x)
        x = self.dropout2(x)
        
        return x
        
class TransformerEncoder(nn.Module):
    
    def __init__(self,data_shape,num_heads,dropout=.1):
        super(TransformerEncoder,self).__init__()
        self.data_shape = data_shape
        self.patch, self.embed = data_shape
        #self.batch, self.patch, self.embed = data_shape
        self.ln1 = nn.LayerNorm([self.patch,self.embed])
        self.ln2 = nn.LayerNorm([self.patch,self.embed])
        self.MHA = MultiHeadAttention(data_shape,num_heads)
        self.mlp = MLP(data_shape, output_size = self.embed, dropout=dropout)
        
    def forward(self,data):
        
        x = self.ln1(data)
        att_out = self.MHA(x)
        att_out = att_out + data
        after_ln2 = self.ln2(att_out)
        after_ln2 = self.mlp(after_ln2)
        after_ln2 = after_ln2 + att_out
        
        return after_ln2
        
class VisionTransformer(nn.Module):
    def __init__(self,data_shape,num_heads,num_layers = 6,dropout = .1):
        super(VisionTransformer,self).__init__()
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                f'{i}', TransformerEncoder(data_shape=data_shape,num_heads = num_heads,dropout = dropout))
    
    def forward(self,data):
        x = data
        for blk in self.blks:
            x = blk(x)
        return x
    
class ClassificationHead(nn.Module):
    def __init__(self,
                 input_layer,
                 hidden_layer_1,
                 hidden_layer_2,
                 hidden_layer_3,
                 hidden_layer_4,
                 hidden_layer_5,
                 num_output,
                 dropout=.1):
        super(ClassificationHead,self).__init__()
        self.ln1 = nn.LayerNorm(input_layer)
        self.fnn1 = nn.Linear(input_layer,hidden_layer_1)
        self.dropout_1 = nn.Dropout(dropout)
        self.ln2 = nn.LayerNorm(hidden_layer_1)
        self.fnn2 = nn.Linear(hidden_layer_1,hidden_layer_2)
        self.dropout_2 = nn.Dropout(dropout)
        self.ln3 = nn.LayerNorm(hidden_layer_2)
        self.fnn3 = nn.Linear(hidden_layer_2,hidden_layer_3)
        self.dropout_3 = nn.Dropout(dropout)
        self.ln4 = nn.LayerNorm(hidden_layer_3)
        self.fnn4 = nn.Linear(hidden_layer_3,hidden_layer_4)
        self.dropout_4 = nn.Dropout(dropout)
        self.ln5 = nn.LayerNorm(hidden_layer_4)
        self.fnn5 = nn.Linear(hidden_layer_4,hidden_layer_5)
        self.dropout_5 = nn.Dropout(dropout)
        self.fnn6 = nn.Linear(hidden_layer_5,num_output)
        
    def forward(self,data):
        x = self.ln1(data)
        x = self.fnn1(x)
        x = self.dropout_1(x)
        x = self.ln2(x)
        x = self.fnn2(x)
        x = self.dropout_2(x)
        x = self.ln3(x)
        x = self.fnn3(x)
        x = self.dropout_3(x)
        x = self.ln4(x)
        x = self.fnn4(x)
        x = self.dropout_4(x)
        x = self.ln5(x)
        x = self.fnn5(x)
        x = self.dropout_5(x)
        x = self.fnn6(x)
        
        return x
    
class PRPModel(nn.Module):
    def __init__(self,
                 data_shape,
                 num_patch:int,
                 num_heads,
                 num_output,
                 num_layers,
                 conv_output_dim,
                 hidden_layer_1,
                 hidden_layer_2,
                 hidden_layer_3,
                 hidden_layer_4,
                 hidden_layer_5,
                dropout = .1):
        super(PRPModel,self).__init__()
        self.batch_size = data_shape[0]
        self.data_shape = data_shape[1:]
        
        self.conv_layer = Convlayer(data_shape,num_patch,conv_output_dim)
        self.in_channels, self.row_len, self.col_len = self.data_shape
        assert self.row_len % num_patch == 0 
        assert self.col_len % num_patch == 0
        
        patch_row = self.row_len // num_patch
        patch_col = self.col_len // num_patch
        
        embed_dim = patch_row * patch_col * self.in_channels
        assert embed_dim % num_heads == 0, f"embed_dimension is not divisible by num_heads \nembed_dim: {embed_dim},heads:{num_heads}"
        
        #self.data_shape = (self.batch_size,num_patch**2,patch_row * patch_col * self.in_channels)
        self.data_shape = (num_patch**2,patch_row * patch_col * self.in_channels)
        
        self.pos_encode = PositionalEncoding(self.data_shape,dropout)
        
        
        self.visual_transformer = VisionTransformer(self.data_shape,num_heads,num_layers,dropout)
        
        #self.input_layer = self.data_shape[1] * self.data_shape[2]
        #self.input_layer = self.data_shape[0] * self.data_shape[1]
        self.input_layer = self.data_shape[0] * self.data_shape[1]
        
        #self.ClassificationHead = ClassificationHead(self.input_layer,
                                                     #hidden_layer_1,
                                                     #hidden_layer_2,
                                                     #hidden_layer_3,
                                                     #num_output,
                                                     #dropout=.1)
                            
        self.ClassificationHead = ClassificationHead(self.data_shape[1],
                                                     hidden_layer_1,
                                                     hidden_layer_2,
                                                     hidden_layer_3,
                                                     hidden_layer_4,
                                                     hidden_layer_5,
                                                     num_output,
                                                     dropout=.1)
        self.softmax = nn.Softmax(dim = 1)
        
    def forward(self,data):
        batch_size = data.shape[0]
        
        x = self.conv_layer(data)
        x = self.pos_encode(x)
        x = self.visual_transformer(x)
        #print(x.shape)
        #x = torch.reshape(x,(batch_size,self.input_layer))
        x = torch.squeeze(x[:,-1,:])
        x = self.ClassificationHead(x)
        #print(f"BEFORE SOFTMAX: {x}")
        x = self.softmax(x)
        #x = torch.argmax(x,axis = 1)
        return x

In [53]:
torch.zeros(5,1600)[-1,:].size()

torch.Size([1600])

In [54]:
def main(dict_path: str,
        data_path:str,
        batch_size: int ,
        sequential : bool , 
        split: tuple , 
        image_size: tuple,
        num_patch: int,
        num_heads: int,
        vision_num_layers :int,
        conv_output_dim:int,
        hidden_layer_1:int,
        hidden_layer_2:int,
        hidden_layer_3:int,
        hidden_layer_4:int,
        hidden_layer_5:int,
        num_classification_output : int,
        dropout: float,
        save_interval:int,
        metric_interval:int,
        save_path:str,
        num_epochs:int
               ):
    train_data , test_data = PRPDataLoader(dict_path,data_path,num_classification_output,image_size,sequential,split,batch_size)
    in_channels, row_len, col_len = 4, image_size[0], image_size[1]
    
    assert row_len % num_patch == 0, 'row_len not divisible by num_patches' 
    assert col_len % num_patch == 0, "col_len not divisible by num_patches"
        
    patch_row = row_len // num_patch
    patch_col = col_len // num_patch
    embed_dim = patch_row * patch_col * in_channels
    assert embed_dim % num_heads == 0, "embed_dimension is not divisible by num_heads"

    model = PRPModel((batch_size,4,image_size[0],image_size[1]),
                     num_patch,
                     num_heads,
                     num_classification_output,
                     vision_num_layers,
                     conv_output_dim,
                     hidden_layer_1,
                     hidden_layer_2,
                     hidden_layer_3,
                     hidden_layer_4,
                     hidden_layer_5,
                     dropout)
    
    adam_optimizer = torch.optim.Adam(
        model.parameters(), lr=0.001, weight_decay=0.0001)
    
    mse_loss = torch.nn.MSELoss()

    trainer = Trainer(model,adam_optimizer,mse_loss,save_interval,metric_interval,train_data,test_data= test_data,save_path = save_path)

    trainer.train(num_epochs=num_epochs)

In [57]:
main(dict_path = "test_dict",
     data_path = "",
     batch_size = 5,
     sequential = False,
     split = (.8,.2),
     image_size = (300,300),
     num_patch = 30,
     num_heads = 5,
     vision_num_layers = 8,
     conv_output_dim = None,
     hidden_layer_1 = 1600,
     hidden_layer_2 = 800,
     hidden_layer_3 = 400,
     hidden_layer_4 = 100,
     hidden_layer_5 = 25,
     num_classification_output = 7,
     dropout = .1,
     save_interval = 1000000,
     metric_interval = 1,
     save_path = "",
     num_epochs = 100)

running 1 epoch
/tTRAINING SET VALUES
		 NUMBER CORRECT: 2
		 ACCURACY: 0.25
		 MEAN ABSOLUTE ERROR: 2.0
		 LOSS: 7.75
		 PREDICTED: tensor([5, 6, 6, 1, 6, 0, 0, 6])
		 LABELS: tensor([1, 5, 6, 3, 0, 1, 0, 4])
++++++++++++++++++++++++++++++++++++++++++++++++++++
/tTEST SET VALUES
		 NUMBER CORRECT: 0
		 ACCURACY: 0.0
		 MEAN ABSOLUTE ERROR: 1.0
		 LOSS: 1.0
		 PREDICTED: tensor([1, 1])
		 LABELS: tensor([2, 2])
++++++++++++++++++++++++++++++++++++++++++++++++++++
running 2 epoch
/tTRAINING SET VALUES
		 NUMBER CORRECT: 1
		 ACCURACY: 0.125
		 MEAN ABSOLUTE ERROR: 2.125
		 LOSS: 6.625
		 PREDICTED: tensor([1, 1, 1, 1, 1, 2, 4, 4])
		 LABELS: tensor([6, 1, 4, 0, 3, 0, 1, 5])
++++++++++++++++++++++++++++++++++++++++++++++++++++
/tTEST SET VALUES
		 NUMBER CORRECT: 0
		 ACCURACY: 0.0
		 MEAN ABSOLUTE ERROR: 1.0
		 LOSS: 1.0
		 PREDICTED: tensor([1, 1])
		 LABELS: tensor([2, 2])
++++++++++++++++++++++++++++++++++++++++++++++++++++
running 3 epoch
/tTRAINING SET VALUES
		 NUMBER CORRECT: 2
	

/tTEST SET VALUES
		 NUMBER CORRECT: 0
		 ACCURACY: 0.0
		 MEAN ABSOLUTE ERROR: 1.5
		 LOSS: 2.5
		 PREDICTED: tensor([0, 3])
		 LABELS: tensor([2, 2])
++++++++++++++++++++++++++++++++++++++++++++++++++++
running 19 epoch
/tTRAINING SET VALUES
		 NUMBER CORRECT: 4
		 ACCURACY: 0.5
		 MEAN ABSOLUTE ERROR: 0.875
		 LOSS: 1.625
		 PREDICTED: tensor([3, 3, 4, 3, 3, 0, 0, 5])
		 LABELS: tensor([1, 3, 4, 1, 5, 0, 0, 6])
++++++++++++++++++++++++++++++++++++++++++++++++++++
/tTEST SET VALUES
		 NUMBER CORRECT: 0
		 ACCURACY: 0.0
		 MEAN ABSOLUTE ERROR: 2.5
		 LOSS: 6.5
		 PREDICTED: tensor([5, 0])
		 LABELS: tensor([2, 2])
++++++++++++++++++++++++++++++++++++++++++++++++++++
running 20 epoch
/tTRAINING SET VALUES
		 NUMBER CORRECT: 4
		 ACCURACY: 0.5
		 MEAN ABSOLUTE ERROR: 1.0
		 LOSS: 2.75
		 PREDICTED: tensor([5, 0, 3, 3, 5, 3, 0, 5])
		 LABELS: tensor([6, 0, 1, 4, 5, 3, 0, 1])
++++++++++++++++++++++++++++++++++++++++++++++++++++
/tTEST SET VALUES
		 NUMBER CORRECT: 0
		 ACCURACY: 0.0
		 ME

/tTRAINING SET VALUES
		 NUMBER CORRECT: 8
		 ACCURACY: 1.0
		 MEAN ABSOLUTE ERROR: 0.0
		 LOSS: 0.0
		 PREDICTED: tensor([1, 3, 6, 0, 1, 5, 4, 0])
		 LABELS: tensor([1, 3, 6, 0, 1, 5, 4, 0])
++++++++++++++++++++++++++++++++++++++++++++++++++++
/tTEST SET VALUES
		 NUMBER CORRECT: 0
		 ACCURACY: 0.0
		 MEAN ABSOLUTE ERROR: 1.5
		 LOSS: 2.5
		 PREDICTED: tensor([1, 0])
		 LABELS: tensor([2, 2])
++++++++++++++++++++++++++++++++++++++++++++++++++++
running 37 epoch
/tTRAINING SET VALUES
		 NUMBER CORRECT: 7
		 ACCURACY: 0.875
		 MEAN ABSOLUTE ERROR: 0.125
		 LOSS: 0.125
		 PREDICTED: tensor([5, 0, 1, 5, 3, 0, 1, 4])
		 LABELS: tensor([6, 0, 1, 5, 3, 0, 1, 4])
++++++++++++++++++++++++++++++++++++++++++++++++++++
/tTEST SET VALUES
		 NUMBER CORRECT: 0
		 ACCURACY: 0.0
		 MEAN ABSOLUTE ERROR: 3.0
		 LOSS: 10.0
		 PREDICTED: tensor([0, 6])
		 LABELS: tensor([2, 2])
++++++++++++++++++++++++++++++++++++++++++++++++++++
running 38 epoch
/tTRAINING SET VALUES
		 NUMBER CORRECT: 6
		 ACCURACY: 0.7

KeyboardInterrupt: 

In [18]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

NameError: name 'model' is not defined

In [11]:
test_dict = {"vertex1" : 1, "vertex2" : 2, "vertex3": 3,"vertex4":4,
             "vertex5":5,"vertex6":6,"vertex7":7,"vertex8":1,"vertex9":2,"vertex10":3}
cPickle.dump(test_dict, open(f'test_dict', 'wb'))

In [282]:
x = torch.zeros(2,4,200,200)
patches = x.unfold(2, 20, 20)
print(patches.shape)
patches = patches.unfold(3, 20, 20)

patches = torch.reshape(patches,(2,100,4 * 20 * 20))

torch.Size([2, 4, 10, 200, 20])


In [266]:
patches.shape

torch.Size([100, 1600])

In [101]:
practice.unfold(1,20,20).shape

torch.Size([4, 10, 200, 20])