# Training the Classification Models


In [7]:
from lib import dataLoader
from lib import model_eval
from lib import models
from torch.utils.tensorboard import SummaryWriter
import torch
import torchvision
import os
from inspect import getmembers, isclass

from torch.nn import CrossEntropyLoss
from torch.optim import Adam

First, initialize the variables common to all models. Every model uses the same training dataloader and the same device.

In [2]:
#get the training dataloader, and use the GPU if one is available (otherwise fall back to the CPU)
path = "C:/Users/Ryan/.cache/BreaKHis_split/train"
train = dataLoader.make_dataloader(path, train=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Then, train each model defined in the lib/models.py file.  
Each model needs its own optimizer, because an Adam optimizer is built around one model's parameters and keeps its running statistics for exactly those parameters.  
For each model, training metrics can be watched in real time by running *tensorboard --logdir=summary_writers/MODEL_NAME* in a terminal; the per-epoch results are also printed by the training function.
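To see why one optimizer cannot be shared between models, here is a minimal pure-Python sketch of the Adam update rule over scalar parameters. `TinyAdam` is a hypothetical illustration, not `torch.optim.Adam`, although its default hyperparameters (lr=1e-3, betas=(0.9, 0.999), eps=1e-8) match PyTorch's defaults. The point is that the optimizer holds a reference to one model's parameters plus per-parameter moment estimates, so a second model needs a second optimizer.

```python
import math


class TinyAdam:
    """Illustrative Adam over a dict of scalar parameters (hypothetical, not torch.optim.Adam)."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        self.params = params  # reference to ONE model's parameters
        self.lr, self.eps = lr, eps
        self.beta1, self.beta2 = betas
        self.m = {name: 0.0 for name in params}  # first-moment estimate per parameter
        self.v = {name: 0.0 for name in params}  # second-moment estimate per parameter
        self.t = 0  # step counter used for bias correction

    def step(self, grads):
        """Apply one Adam update, given gradients keyed like self.params."""
        self.t += 1
        for name, g in grads.items():
            self.m[name] = self.beta1 * self.m[name] + (1 - self.beta1) * g
            self.v[name] = self.beta2 * self.v[name] + (1 - self.beta2) * g * g
            m_hat = self.m[name] / (1 - self.beta1 ** self.t)
            v_hat = self.v[name] / (1 - self.beta2 ** self.t)
            self.params[name] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


# Two hypothetical "models", each a single weight, each with its own optimizer.
model_a = {"w": 0.0}
model_b = {"w": 10.0}
opt_a = TinyAdam(model_a, lr=0.1)
opt_b = TinyAdam(model_b, lr=0.1)

for _ in range(2000):
    opt_a.step({"w": 2 * (model_a["w"] - 3)})  # gradient of (w - 3)^2
    opt_b.step({"w": 2 * (model_b["w"] + 1)})  # gradient of (w + 1)^2

print(model_a["w"], model_b["w"])
```

Each optimizer's moment estimates only ever see its own model's gradients; reusing `opt_a` for `model_b` would update the wrong dictionary and mix the two models' statistics.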

In [9]:
#loss function is the same for all
loss_fxn = CrossEntropyLoss()

#get list of models
model_list = getmembers(models, isclass)
print(model_list)

[('DenseNet_model', <class 'lib.models.DenseNet_model'>), ('MobileNet_model', <class 'lib.models.MobileNet_model'>), ('ResNet_model', <class 'lib.models.ResNet_model'>), ('VGG_model', <class 'lib.models.VGG_model'>)]
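`inspect.getmembers` returns `(name, value)` pairs sorted by name, which is why `model_list[0]` is DenseNet and `model_list[3]` is VGG. It also returns classes that are merely imported into the module, so an extra class import in lib/models.py would silently shift those indices. The sketch below uses a throwaway in-memory module (`fake_models`, a hypothetical stand-in for lib/models.py) to show both behaviors, plus one way to keep only classes defined in the module by checking `__module__`.

```python
import inspect
import types

# Hypothetical stand-in for lib/models.py, built in memory for illustration.
fake_models = types.ModuleType("fake_models")
exec(
    "from collections import OrderedDict\n"  # an imported class, like a stray import
    "class VGG_model:\n    pass\n"
    "class DenseNet_model:\n    pass\n",
    fake_models.__dict__,
)

# getmembers returns (name, value) pairs sorted by name, imported classes included.
all_classes = inspect.getmembers(fake_models, inspect.isclass)
print([name for name, _ in all_classes])  # ['DenseNet_model', 'OrderedDict', 'VGG_model']

# Keep only classes whose __module__ is this module, i.e. classes defined in it.
own_classes = [
    (name, cls)
    for name, cls in all_classes
    if cls.__module__ == fake_models.__name__
]
print([name for name, _ in own_classes])  # ['DenseNet_model', 'VGG_model']
```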


It's best to train each model separately, instead of all at once.  
That way, if one model fails for some reason, not all of them need to be retrained.  
(I may or may not have learned this the hard way.)
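If you would rather drive all four models from a single cell, a similar safety net can be had by isolating each run in a `try`/`except`, so one crash is recorded instead of aborting the loop. `train_all_isolated` and its `train_one` callback are hypothetical names; here `train_one` takes one `(name, class)` tuple, so a function like `run_train_fxn` could be passed in directly.

```python
def train_all_isolated(model_list, train_one):
    """Run train_one on each (name, cls) pair; a failure is recorded, not raised.

    train_one is a hypothetical callback that takes one (name, cls) tuple.
    """
    results = {}
    for m in model_list:
        name = m[0]
        try:
            train_one(m)
            results[name] = "ok"
        except Exception as exc:  # deliberately broad: record the error and move on
            results[name] = f"failed: {exc!r}"
    return results


# Usage with a fake trainer that crashes on one model.
def fake_train(m):
    if m[0] == "ResNet_model":
        raise RuntimeError("CUDA out of memory")


fake_list = [("DenseNet_model", object), ("ResNet_model", object), ("VGG_model", object)]
print(train_all_isolated(fake_list, fake_train))
```

Only the entries marked as failed then need to be rerun. Note that `KeyboardInterrupt` is not a subclass of `Exception`, so interrupting the kernel still stops the whole loop.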

In [15]:
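def run_train_fxn(m):
    #m is a (name, class) tuple from model_list
    name, model_cls = m
    writer = SummaryWriter(log_dir=os.path.join("summary_writers", name))
    model = model_cls()
    #each model gets its own optimizer over its own parameters
    optimizer = Adam(model.parameters())
    model_eval.train_model(model, name, train, writer, loss_fxn, optimizer, device, epochs=25)
    #flush any pending TensorBoard events to disk
    writer.close()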

### Training DenseNet model

In [17]:
run_train_fxn(model_list[0])

Training DenseNet_model:
Epoch: 1
Training Accuracy: 0.6497838171710932
Training Precision: 0.6893203883495146
Training Recall: 0.8930817610062893
Training F1: 0.7780821917808219


Epoch: 2
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 3
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 4
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 5
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 6
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 7
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Trainin
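From epoch 2 onward DenseNet's metrics freeze: recall is exactly 1.0 and accuracy equals precision (0.6874613959234095). Recall of 1.0 means there are no false negatives, and with no false negatives accuracy can only equal precision (short of a perfect score) when there are also no true negatives, i.e. the model is predicting the positive class for every training image. Accuracy and precision then both equal the fraction of positive samples p, and F1 reduces to 2p / (1 + p). The same values appear in the early epochs of the MobileNet, ResNet, and VGG runs. The pure-Python sketch below checks this on a small made-up label set (11 positives, 5 negatives, not the BreaKHis data) and then plugs the logged p into the formula.

```python
def binary_metrics(y_true, y_pred):
    """Return (accuracy, precision, recall, f1) for lists of 0/1 labels (1 = positive)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    accuracy = (tp + tn) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1


# Made-up labels: 11 positives and 5 negatives (positive rate 0.6875).
y_true = [1] * 11 + [0] * 5
y_pred = [1] * len(y_true)  # degenerate model: always predicts the positive class

acc, prec, rec, f1 = binary_metrics(y_true, y_pred)
print(acc, prec, rec, f1)

# With recall fixed at 1.0, F1 = 2p / (1 + p), where p is the positive rate.
p = 0.6874613959234095  # accuracy/precision logged from epoch 2 onward
print(2 * p / (1 + p))  # close to the logged F1 of 0.8147877013177159
```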

### Training MobileNet model

In [19]:
run_train_fxn(model_list[1])

Training MobileNet_model:
Epoch: 1
Training Accuracy: 0.6757257566399012
Training Precision: 0.6870229007633588
Training Recall: 0.9703504043126685
Training F1: 0.8044692737430168


Epoch: 2
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 3
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 4
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 5
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 6
Training Accuracy: 0.6886967263743051
Training Precision: 0.6911487758945386
Training Recall: 0.9892183288409704
Training F1: 0.8137472283813747


Epoch: 7
Training Accuracy: 0.8684373069796171
Training Precision: 0.8926701570680629
Training Rec

### Training ResNet model

In [20]:
run_train_fxn(model_list[2])

Training ResNet_model:
Epoch: 1
Training Accuracy: 0.6695491043854231
Training Precision: 0.6960651289009498
Training Recall: 0.921832884097035
Training F1: 0.793196752995748


Epoch: 2
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 3
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 4
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 5
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 6
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 7
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1

### Training VGG model

In [21]:
run_train_fxn(model_list[3])

Training VGG_model:
Epoch: 1
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 2
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 3
Training Accuracy: 0.6874613959234095
Training Precision: 0.6874613959234095
Training Recall: 1.0
Training F1: 0.8147877013177159


Epoch: 4
Training Accuracy: 0.7801111797405806
Training Precision: 0.7843726521412472
Training Recall: 0.9380053908355795
Training F1: 0.8543371522094927


Epoch: 5
Training Accuracy: 0.8165534280420013
Training Precision: 0.8469387755102041
Training Recall: 0.894878706199461
Training F1: 0.8702490170380078


Epoch: 6
Training Accuracy: 0.8214947498455837
Training Precision: 0.8497453310696095
Training Recall: 0.89937106918239
Training F1: 0.8738542121344391


Epoch: 7
Training Accuracy: 0.8363187152563311
Training Precision: 0.8509933774834437
Traini