In [1]:
from dataset import XRayDatasetResnet
import numpy as np
import copy
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
import os
from sklearn.metrics import f1_score
from datetime import datetime
import matplotlib.pyplot as plt

# Reached only when every import in this cell has succeeded.
print("Setup done!")

Setup done!


In [2]:
# Per-channel normalisation statistics (the standard ImageNet values used by
# torchvision's pretrained models). NOTE(review): assumes the dataset yields
# 3-channel images — confirm XRayDatasetResnet converts grayscale X-rays to RGB.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

preprocess = transforms.Compose(
    [
        transforms.Resize(256),      # resize (shorter side) to 256 px
        transforms.CenterCrop(224),  # central 224x224 crop
        transforms.ToTensor(),       # image -> float tensor
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [3]:
# Train and eval splits for the baseline.
# The data root can be overridden with the XRAY_DATA_DIR environment variable;
# the default reproduces the original hard-coded "/data" location.
DATA_DIR = os.environ.get("XRAY_DATA_DIR", "/data")

# os.path.join(..., "") keeps the trailing separator of the original paths
# ("/data/train/", "/data/eval/"). NOTE(review): confirm whether
# XRayDatasetResnet concatenates file names onto this string.
train_set = XRayDatasetResnet(os.path.join(DATA_DIR, "train", ""), preprocess)
eval_set = XRayDatasetResnet(os.path.join(DATA_DIR, "eval", ""), preprocess)

# Base directory for saved models.
root = os.getcwd()

In [5]:
# Seed torch's global RNG so the new classifier head's initialisation (and later
# torch-RNG consumers such as DataLoader shuffling) are reproducible.
SEED = 0
torch.manual_seed(SEED)

NUM_CLASSES = 4

model_ft = models.resnet101(pretrained=True)
num_ftrs = model_ft.fc.in_features


def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of `model` when `feature_extracting` is True.

    Args:
        model: nn.Module whose parameters are modified in place.
        feature_extracting: if True, set `requires_grad = False` on all
            parameters; if False, leave the model unchanged.
    """
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


# Feature extraction: freeze the whole pretrained backbone...
set_parameter_requires_grad(model_ft, feature_extracting=True)

# ...then replace the classifier. The freshly created layer's parameters have
# requires_grad=True, so only this head is trained.
model_ft.fc = nn.Linear(num_ftrs, NUM_CLASSES)

print("Params to learn:")
params_to_update = []
for name, param in model_ft.named_parameters():
    if param.requires_grad:
        params_to_update.append(param)
        print("\t", name)

# Parameters
batch_size = 32
epochs = 25
learning_rate = 0.001
# Passed to Adam as its numerical-stability term (PyTorch default 1e-8); it is
# NOT an early-stopping threshold. NOTE(review): confirm 1e-4 is intended.
eps = 0.0001

# NLLLoss expects log-probabilities; Train applies F.log_softmax before it.
loss_function = nn.NLLLoss()
optimizer = optim.Adam(params_to_update, lr=learning_rate, eps=eps)

Params to learn:
	 fc.weight
	 fc.bias


In [6]:
#for param in model_ft.parameters():
#    print(param)


In [7]:
def _run_epoch(network, loader, is_train, opt, criterion):
    """Run one full pass of `network` over `loader`.

    Args:
        network: the model; put in train mode if `is_train`, else eval mode.
        loader: iterable yielding (images, labels) batches.
        is_train: if True, backpropagate and step `opt` after every batch.
        opt: optimizer used for the update (unused when `is_train` is False).
        criterion: loss applied to log-softmax outputs; assumed to use mean
            reduction, as nn.NLLLoss() does by default.

    Returns:
        (loss, accuracy, macro_f1, corrects): the mean per-image loss over the
        pass, accuracy and macro F1 over all predictions of the pass, and the
        number of correct predictions.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    # NOTE(review): in train mode BatchNorm layers still update their running
    # statistics even though the backbone's parameters are frozen — confirm
    # this is intended for feature extraction.
    network.train(is_train)

    running_loss = 0.0
    corrects = 0
    total_images = 0
    true_batches = []
    pred_batches = []

    for images, labels in loader:
        labels = torch.as_tensor(labels)
        if is_train:
            opt.zero_grad()

        with torch.set_grad_enabled(is_train):
            output = network(images)
            log_output = F.log_softmax(output, dim=1)
            loss = criterion(log_output, labels)
            label_pred = log_output.argmax(dim=1)
            if is_train:
                loss.backward()
                opt.step()

        n_batch = labels.size(0)
        # Undo the per-batch mean so the epoch loss is a true per-image mean
        # (the last batch may be smaller than the others).
        running_loss += loss.item() * n_batch
        corrects += (label_pred == labels).sum().item()
        total_images += n_batch
        true_batches.append(labels.cpu())
        pred_batches.append(label_pred.cpu())

    if total_images == 0:
        raise ValueError("loader yielded no samples")

    y_true = torch.cat(true_batches).numpy()
    y_pred = torch.cat(pred_batches).numpy()
    # Macro F1 over the whole epoch: averaging per-batch macro F1 scores is
    # biased, especially for batches that lack some classes.
    fscore = f1_score(y_true, y_pred, average='macro')
    return running_loss / total_images, corrects / total_images, fscore, corrects


def Train(num_epochs, network, loader_train, loader_eval, patience=5, opt=None, criterion=None):  # TODO: add preprocess = unet
    """Train `network`, evaluating after each epoch, and keep the best weights.

    The weights with the lowest evaluation loss are loaded back into `network`
    before returning. Training stops early once the evaluation loss has not
    improved for `patience` consecutive epochs.

    Args:
        num_epochs: maximum number of epochs.
        network: model to train (modified in place).
        loader_train: loader yielding (images, labels) training batches.
        loader_eval: loader yielding (images, labels) evaluation batches.
        patience: epochs without eval-loss improvement before stopping early.
        opt: optimizer; defaults to the notebook-level `optimizer`.
        criterion: loss on log-probabilities; defaults to the notebook-level
            `loss_function`.

    Returns:
        (network, loss_train, acc_train, fs_train, loss_eval, acc_eval, fs_eval):
        the network with the best weights loaded, plus per-epoch lists of
        loss, accuracy and macro F1 for each split.
    """
    if opt is None:
        opt = optimizer
    if criterion is None:
        criterion = loss_function

    # Best model so far (lowest eval loss). inf guarantees the first epoch
    # always counts as an improvement, whatever the loss scale.
    best_net = copy.deepcopy(network.state_dict())
    best_loss = float('inf')
    best_acc = 0.0
    epochs_no_improve = 0

    # Per-epoch history of the training metrics.
    loss_train_evo, acc_train_evo, fs_train_evo = [], [], []
    # Per-epoch history of the evaluation metrics.
    loss_eval_evo, acc_eval_evo, fs_eval_evo = [], [], []

    for epoch in range(num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, num_epochs))

        for phase, loader in (("train", loader_train), ("eval", loader_eval)):
            epoch_loss, epoch_acc, epoch_fscore, running_corrects = _run_epoch(
                network, loader, phase == 'train', opt, criterion)

            print('{}: Loss: {:.4f}, Ac: {:.4f}, fs: {:.4f}, rc: {}'
                  .format(phase, epoch_loss, epoch_acc, epoch_fscore, running_corrects))

            if phase == 'train':
                loss_train_evo.append(epoch_loss)
                acc_train_evo.append(epoch_acc)
                fs_train_evo.append(epoch_fscore)
            else:
                loss_eval_evo.append(epoch_loss)
                acc_eval_evo.append(epoch_acc)
                fs_eval_evo.append(epoch_fscore)
                eval_loss, eval_acc = epoch_loss, epoch_acc

        # Early stopping on the eval loss of the epoch just finished; a tie
        # counts as "no improvement".
        if eval_loss < best_loss:
            best_acc = eval_acc
            best_loss = eval_loss
            best_net = copy.deepcopy(network.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("Early stop!")
                break

    print("Best accuracy eval %.2f:" % best_acc)
    # Restore the weights with the lowest eval loss.
    network.load_state_dict(best_net)

    return network, loss_train_evo, acc_train_evo, fs_train_evo, loss_eval_evo, acc_eval_evo, fs_eval_evo

In [8]:
# Only the training split is shuffled; shuffling the eval split is unnecessary
# and a fixed order makes evaluation passes reproducible.
train_dataloader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
)

eval_dataloader = torch.utils.data.DataLoader(
    eval_set,
    batch_size=batch_size,
    shuffle=False,
)

In [9]:
start_time = datetime.now()
print("Training for baseline starts...")
best_net, loss_train, acc_train, fs_train, loss_val, acc_val, fs_val = Train(
    epochs, model_ft, train_dataloader, eval_dataloader)
print("Training complete")
finish_time = datetime.now()
# NOTE(review): this name shadows the stdlib `time` module if it is ever imported.
time = finish_time - start_time
print("Time required to train: ", time)

# torch.save does not create missing directories, so make sure it exists.
save_dir = os.path.join(root, 'trained_models')
os.makedirs(save_dir, exist_ok=True)
# Saves the whole pickled module (loading requires the same code layout).
# NOTE(review): PyTorch recommends saving best_net.state_dict() instead; kept
# as-is because existing loaders may expect a full module.
torch.save(best_net, os.path.join(save_dir, 'resnet_baseline'))

Training for baseline stars...
Epoch [1/25]


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


KeyboardInterrupt: 

In [None]:
def plot_metric(metric_train, metric_val, title):
    %matplotlib inline
    fig, (ax) = plt.subplots(1, 1)
    fig.suptitle(title)
    ax.set(xlabel='epoch')
    ax.plot(metric_train, label='Training')
    ax.plot(metric_val, label='Validation')
    ax.legend(loc='upper left')


In [None]:
# Loss curves, training vs. validation, per epoch.
plot_metric(metric_train=loss_train, metric_val=loss_val, title='Loss');


In [None]:
# Accuracy curves, training vs. validation, per epoch.
plot_metric(metric_train=acc_train, metric_val=acc_val, title='Accuracy');

In [None]:
# Macro F-score curves, training vs. validation, per epoch.
plot_metric(metric_train=fs_train, metric_val=fs_val, title='F-Score');