conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.1 -c pytorch

Install ResNeSt from its GitHub repository: pip install git+https://github.com/zhanghang1989/ResNeSt

or install the pre-release from PyPI: pip install resnest --pre

In [41]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Sampler
from PIL import Image, ImageOps
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from resnest.torch import resnest50


import time
import pickle
import numpy as np
from torchvision.transforms import Lambda
import argparse
import copy
import random
import numbers
from torch.utils.tensorboard import SummaryWriter
from sklearn import metrics
import os

In [14]:
# Pick the default compute device: GPU index 2 when CUDA is usable, else CPU.
cuda_ready = torch.cuda.is_available()
device = "cuda:2" if cuda_ready else "cpu"

# Echo the availability flag so the notebook output records it.
print(torch.cuda.is_available())

True


# 1. Models: M1, M2, M3

In [3]:
# Model 1  RCNet 
class M1_resnet_lstm(torch.nn.Module):
    """Model 1 (RCNet): ResNet-50 frame encoder followed by an LSTM.

    The torchvision ResNet-50 (requested with ``pretrained=True``) is kept up
    to and including its average-pooling layer and yields 2048 features per
    frame. A single-layer LSTM with 512 hidden units runs over each clip, and
    a linear layer maps every time step to 19 class logits.

    Note:
        ``forward`` reads the module-level global ``sequence_length`` to
        regroup frames into clips, so that name must be defined before the
        model is called.
    """

    def __init__(self):
        super(M1_resnet_lstm, self).__init__()
        resnet = models.resnet50(pretrained=True)

        # Shared CNN trunk: every ResNet stage except the final classifier.
        # Sub-module names are kept so existing checkpoints still load.
        self.share = torch.nn.Sequential()
        for stage in ("conv1", "bn1", "relu", "maxpool",
                      "layer1", "layer2", "layer3", "layer4", "avgpool"):
            self.share.add_module(stage, getattr(resnet, stage))

        self.dropout = nn.Dropout(p=0.2)
        self.lstm = nn.LSTM(2048, 512, batch_first=True)  # 2048-d frame feature -> 512-d state
        self.fc = nn.Linear(512, 19)                      # 512-d state -> 19 classes

        # Xavier initialisation for the first-layer LSTM weight matrices and the classifier.
        init.xavier_normal_(self.lstm.all_weights[0][0])
        init.xavier_normal_(self.lstm.all_weights[0][1])
        init.xavier_uniform_(self.fc.weight)

    def forward(self, x):
        """Return one row of 19 logits per input frame.

        Args:
            x: tensor that can be viewed as ``(-1, 3, 216, 216)``; the number
                of frames must be a multiple of ``sequence_length``.

        Returns:
            Tensor of shape ``(num_frames, 19)``, ordered clip by clip.
        """
        x = x.view(-1, 3, 216, 216)             # flatten clips into a batch of frames
        # Call the module itself rather than .forward() so registered hooks run.
        x = self.share(x)
        x = x.view(-1, sequence_length, 2048)   # regroup frame features into clips

        self.lstm.flatten_parameters()
        y, _ = self.lstm(x)
        y = y.contiguous().view(-1, 512)        # one 512-d feature per frame
        y = self.dropout(y)
        y = self.fc(y)
        return y

In [9]:
# M2: densenet + lstm
class M2_densenet_lstm(torch.nn.Module):
    """Model 2: DenseNet-169 frame encoder followed by an LSTM.

    The ``features`` part of the torchvision DenseNet-169 (requested with
    ``pretrained=True``) is followed by a 6x6 average pool, giving 1664
    features per frame. A single-layer LSTM with 512 hidden units runs over
    each clip, and a linear layer maps every time step to 19 class logits.

    Note:
        ``forward`` reads the module-level global ``sequence_length`` to
        regroup frames into clips, so that name must be defined before the
        model is called.
    """

    def __init__(self):
        super(M2_densenet_lstm, self).__init__()
        densenet = models.densenet169(pretrained=True)

        # Shared CNN trunk: only the convolutional feature extractor is reused.
        self.share = torch.nn.Sequential()
        self.share.add_module("features", densenet.features)
        # Collapses the spatial map to one value per channel; the kernel size
        # of 6 is tied to the hard-coded 216x216 input used in forward().
        self.avg = nn.AvgPool2d(6)

        self.lstm = nn.LSTM(1664, 512, batch_first=True)  # 1664-d frame feature -> 512-d state
        self.fc = nn.Linear(512, 19)                      # 512-d state -> 19 classes

        self.dropout = nn.Dropout(p=0.2)

        # Xavier initialisation for the first-layer LSTM weight matrices and the classifier.
        init.xavier_normal_(self.lstm.all_weights[0][0])
        init.xavier_normal_(self.lstm.all_weights[0][1])
        init.xavier_uniform_(self.fc.weight)

    def forward(self, x):
        """Return one row of 19 logits per input frame.

        Args:
            x: tensor that can be viewed as ``(-1, 3, 216, 216)``; the number
                of frames must be a multiple of ``sequence_length``.

        Returns:
            Tensor of shape ``(num_frames, 19)``, ordered clip by clip.
        """
        x = x.view(-1, 3, 216, 216)             # flatten clips into a batch of frames
        # Call the module itself rather than .forward() so registered hooks run.
        x = self.share(x)
        x = self.avg(x)
        x = x.view(-1, sequence_length, 1664)   # regroup frame features into clips
        self.lstm.flatten_parameters()
        y, _ = self.lstm(x)
        y = y.contiguous().view(-1, 512)        # one 512-d feature per frame
        y = self.dropout(y)
        y = self.fc(y)
        return y


In [6]:
# Model 3  ResNeSt + lstm
class M3_resnest_lstm(torch.nn.Module):
    """Model 3: ResNeSt-50 frame encoder followed by an LSTM.

    The ``resnest50`` network (requested with ``pretrained=True``) is kept up
    to and including its average-pooling layer and yields 2048 features per
    frame. A single-layer LSTM with 512 hidden units runs over each clip, and
    a linear layer maps every time step to 19 class logits.

    Note:
        ``forward`` reads the module-level global ``sequence_length`` to
        regroup frames into clips, so that name must be defined before the
        model is called.
    """

    def __init__(self):
        super(M3_resnest_lstm, self).__init__()
        resnest = resnest50(pretrained=True)

        # Shared CNN trunk: every stage except the final classifier.
        # Sub-module names are kept so existing checkpoints still load.
        self.share = torch.nn.Sequential()
        for stage in ("conv1", "bn1", "relu", "maxpool",
                      "layer1", "layer2", "layer3", "layer4", "avgpool"):
            self.share.add_module(stage, getattr(resnest, stage))

        self.dropout = nn.Dropout(p=0.2)
        self.lstm = nn.LSTM(2048, 512, batch_first=True)  # 2048-d frame feature -> 512-d state
        self.fc = nn.Linear(512, 19)                      # 512-d state -> 19 classes

        # Xavier initialisation for the first-layer LSTM weight matrices and the classifier.
        init.xavier_normal_(self.lstm.all_weights[0][0])
        init.xavier_normal_(self.lstm.all_weights[0][1])
        init.xavier_uniform_(self.fc.weight)

    def forward(self, x):
        """Return one row of 19 logits per input frame.

        Args:
            x: tensor that can be viewed as ``(-1, 3, 216, 216)``; the number
                of frames must be a multiple of ``sequence_length``.

        Returns:
            Tensor of shape ``(num_frames, 19)``, ordered clip by clip.
        """
        x = x.view(-1, 3, 216, 216)             # flatten clips into a batch of frames
        # Call the module itself rather than .forward() so registered hooks run.
        x = self.share(x)
        x = x.view(-1, sequence_length, 2048)   # regroup frame features into clips

        self.lstm.flatten_parameters()
        y, _ = self.lstm(x)
        y = y.contiguous().view(-1, 512)        # one 512-d feature per frame
        y = self.dropout(y)
        y = self.fc(y)
        return y

In [7]:
# Build Model 1 (ResNet-50 + LSTM); its constructor requests pretrained ResNet-50 weights.
M1 = M1_resnet_lstm()

In [10]:
# Build Model 2 (DenseNet-169 + LSTM); its constructor requests pretrained DenseNet-169 weights.
M2 = M2_densenet_lstm()

In [13]:
# Build Model 3 (ResNeSt-50 + LSTM); its constructor requests pretrained ResNeSt-50 weights.
M3 = M3_resnest_lstm()

# 2. Loss function

In [15]:
class WeightedCrossEntropy(torch.nn.Module):
    """Cross-entropy loss with fixed per-class weights (WCE).

    Args:
        weight: 1-D tensor of per-class weights. When omitted (or ``None``),
            a fresh tensor is built from ``DEFAULT_WEIGHT`` (19 classes).

    Note:
        The weights are moved to the module-level global ``device`` when the
        loss is constructed, so ``device`` must be defined beforehand.
    """

    # Default per-class weights (tagged "6-25" in the original notebook).
    # Stored as an immutable tuple so instances never share one tensor.
    DEFAULT_WEIGHT = (
        0.0033, 0.4182, 0.1321, 0.0234, 0.0344, 0.0146, 0.0428, 0.0140, 0.0092,
        0.0272, 0.0096, 0.0323, 0.0341, 0.0508, 0.0151, 0.0160, 0.0365, 0.0738,
        0.0128,
    )

    def __init__(self, weight=None):
        super(WeightedCrossEntropy, self).__init__()

        if weight is None:
            # Built per instance: a tensor default argument would be created
            # once at definition time and shared by every loss object.
            weight = torch.tensor(self.DEFAULT_WEIGHT, dtype=torch.float32)

        weight = weight.to(device)
        self.weighted_cross_entropy = nn.CrossEntropyLoss(weight=weight)

    def forward(self, inputs, target):
        """Return the weighted cross-entropy of ``inputs`` (logits) against ``target``."""
        # Call the wrapped module rather than .forward() so its hooks run.
        return self.weighted_cross_entropy(inputs, target)
    

# 3. Dataset

In [60]:
def pil_loader(path):
    """Open the image file at ``path`` and return it converted to RGB."""
    # The conversion happens while the file is still open; both the file
    # handle and the source image are closed before returning.
    with open(path, 'rb') as handle, Image.open(handle) as picture:
        return picture.convert('RGB')

class CataractsDataset(Dataset):
    """Frame-level dataset pairing image paths with their phase labels.

    Args:
        file_paths: sequence of image file paths.
        file_labels: 2-D array-like indexed as ``file_labels[:, 0]``; only
            the first column is kept as the phase label.
        transform: optional callable applied to each loaded image.
        loader: callable turning a path into an image (``pil_loader`` by default).
    """

    def __init__(self, file_paths,file_labels, transform=None,loader=pil_loader):
        self.file_paths = file_paths
        # Keep only column 0 of the label array.
        self.file_labels_phase = file_labels[:, 0]
        self.transform = transform
        self.loader = loader

    def __len__(self):
        """Number of frames, taken from the path list."""
        return len(self.file_paths)

    def __getitem__(self, index):
        """Return ``(image, label, index)`` for the frame at ``index``."""
        frame_path = self.file_paths[index]
        phase = self.file_labels_phase[index]
        frame = self.loader(frame_path)
        if self.transform is None:
            return frame, phase, index
        return self.transform(frame), phase, index

In [44]:
def get_dataset(data_path):
    """Load a pickled train/validation split and wrap both parts as datasets.

    The pickle must be indexable at positions 0..5, holding in order: train
    paths, val paths, train labels, val labels, frames per train video and
    frames per val video.

    Args:
        data_path: path of the pickle file.

    Returns:
        ``(train_dataset, train_num_each, val_dataset, val_num_each)``.
    """
    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(data_path, 'rb') as f:
        split = pickle.load(f)

    (train_paths_50, val_paths_50,
     train_labels_50, val_labels_50,
     train_num_each_50, val_num_each_50) = (split[slot] for slot in range(6))

    # Report the size of every part (captions kept as in the original output).
    for caption, part in (
            ('train_paths_20  : {:6d}', train_paths_50),
            ('train_labels_20 : {:6d}', train_labels_50),
            ('valid_paths_5  : {:6d}', val_paths_50),
            ('valid_labels_5 : {:6d}', val_labels_50)):
        print(caption.format(len(part)))

    train_labels_50 = np.asarray(train_labels_50, dtype=np.int64)
    val_labels_50 = np.asarray(val_labels_50, dtype=np.int64)

    # Training frames: centre crop plus light colour/flip/rotation augmentation.
    augmenting_steps = [
        transforms.CenterCrop(216),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(5),
        transforms.ToTensor(),
    ]
    # Validation frames: centre crop only.
    plain_steps = [
        transforms.CenterCrop(216),
        transforms.ToTensor(),
    ]
    train_transforms = transforms.Compose(augmenting_steps)
    test_transforms = transforms.Compose(plain_steps)

    train_dataset_50 = CataractsDataset(train_paths_50, train_labels_50, train_transforms)
    val_dataset_50 = CataractsDataset(val_paths_50, val_labels_50, test_transforms)

    return train_dataset_50, train_num_each_50, val_dataset_50, val_num_each_50

In [45]:
# Load the full train/validation split; the relative path is resolved
# against the notebook's current working directory.
train_dataset_50, train_num_each_50, \
val_dataset_50, val_num_each_50 = get_dataset('../../gen_datasets/train_val_paths_labels.pkl')

train_paths_20  :  14160
train_labels_20 :  14160
valid_paths_5  :   2323
valid_labels_5 :   2323


In [46]:
# Display the training dataset object as the cell output.
train_dataset_50

<__main__.CataractsDataset at 0x7f9230252c40>

In [47]:
# Number of frames in each video of the training set.
train_num_each_50

[461,
 625,
 915,
 532,
 852,
 475,
 719,
 467,
 684,
 458,
 607,
 414,
 465,
 2368,
 564,
 515,
 688,
 372,
 619,
 1360]

In [48]:
# Number of frames in each video of the validation set.
val_num_each_50

[479, 394, 583, 442, 425]

# 4. Training

In [55]:
class SequenceSampler(Sampler):
    """Sampler that yields a caller-supplied list of indices in the given order."""

    def __init__(self, data_source, idx):
        super(SequenceSampler, self).__init__(data_source)
        self.idx = idx
        self.data_source = data_source

    def __len__(self):
        """Number of indices that one pass will yield."""
        return len(self.idx)

    def __iter__(self):
        """Iterate over the stored indices without reordering them."""
        return iter(self.idx)


In [63]:
# for sliding window
def get_start_idx(sequence_length, list_each_length):
    """Return every frame index at which a full clip can start.

    Videos are laid out back to back in the dataset. For each video, only
    start positions that leave room for ``sequence_length`` frames inside
    that same video are kept, so no clip spans two videos. A video shorter
    than ``sequence_length`` contributes no start positions.

    Args:
        sequence_length: number of frames per clip.
        list_each_length: number of frames of each video, in dataset order.

    Returns:
        List of global start indices, in ascending order.
    """
    starts = []
    video_offset = 0
    for video_length in list_each_length:
        usable_windows = video_length + 1 - sequence_length
        starts.extend(range(video_offset, video_offset + usable_windows))
        video_offset += video_length
    return starts


In [64]:
def _expand_clip_indices(start_indices, clip_length):
    """Return the frame indices of every clip, clip after clip, in the given start order."""
    return [start + offset
            for start in start_indices
            for offset in range(clip_length)]


def _print_progress(tag, seen, total):
    """Print '<tag> progress' as a percentage, rewriting the same line until `total` is reached."""
    if seen >= total:
        print('%s progress: %s [%d/%d]' % (tag, str(100.0) + '%', total, total), end='\n')
    else:
        percent = round(seen / total * 100, 2)
        print('%s progress: %s [%d/%d]' % (tag, str(percent) + '%', seen, total), end='\r')


def train_model(model,train_dataset, train_num_each, val_dataset, val_num_each):
    """Train ``model`` on fixed-length clips and validate it after every epoch.

    Clips are sliding windows of ``sequence_length`` consecutive frames whose
    start positions come from ``get_start_idx``. Only the prediction for the
    last frame of each clip enters the loss and the accuracy.

    Configuration is read from module-level globals: ``sequence_length``,
    ``train_batch_size``, ``val_batch_size``, ``workers``, ``epochs``,
    ``learning_rate``, ``device``, ``model_save_path`` and
    ``tensorboard_path``. Each batch is reshaped into whole clips, so both
    batch sizes must be multiples of ``sequence_length``.

    After every epoch the best weights so far (highest validation accuracy,
    ties broken by training accuracy) and the latest weights are written
    under ``/media/yilin/catarcnet/``; scalars are logged to TensorBoard.

    Args:
        model: network returning one row of logits per input frame.
        train_dataset: dataset yielding ``(image, label, index)`` items.
        train_num_each: number of frames of each training video, in order.
        val_dataset: validation counterpart of ``train_dataset``.
        val_num_each: number of frames of each validation video, in order.

    Returns:
        The string ``"Complete"``.
    """
    best_model_prefix = "/media/yilin/catarcnet/best_model/" + model_save_path
    latest_model_prefix = "/media/yilin/catarcnet/temp/" + model_save_path + "latest_model_"
    # torch.save does not create missing directories; create them up front so
    # a bad path fails before an epoch of training rather than after it.
    for prefix in (best_model_prefix, latest_model_prefix):
        os.makedirs(os.path.dirname(prefix), exist_ok=True)

    # Start positions of all clips that fit inside a single video.
    train_useful_start_idx = get_start_idx(sequence_length, train_num_each)
    val_useful_start_idx = get_start_idx(sequence_length, val_num_each)

    num_train_we_use = len(train_useful_start_idx)
    num_val_we_use = len(val_useful_start_idx)

    # Validation clips keep a fixed order; training clips are rebuilt per epoch.
    val_idx = _expand_clip_indices(val_useful_start_idx, sequence_length)
    num_train_all = num_train_we_use * sequence_length
    num_val_all = len(val_idx)

    print('num train start idx : {:6d}'.format(num_train_we_use))
    print('num of all train use: {:6d}'.format(num_train_all))
    print('num of all valid use: {:6d}'.format(num_val_all))

    val_loader = DataLoader(
        val_dataset,
        batch_size=val_batch_size,
        sampler=SequenceSampler(val_dataset, val_idx),
        num_workers=workers,
        pin_memory=False
    )

    model.to(device)

    criterion_phase = WeightedCrossEntropy()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_accuracy_phase = 0.0
    correspond_train_acc_phase = 0.0
    best_epoch = 0

    writer = SummaryWriter(tensorboard_path)
    try:
        for epoch in range(epochs):
            torch.cuda.empty_cache()

            # Shuffle whole clips (not frames) so each clip stays contiguous.
            np.random.shuffle(train_useful_start_idx)
            train_idx = _expand_clip_indices(train_useful_start_idx, sequence_length)

            train_loader = DataLoader(
                train_dataset,
                batch_size=train_batch_size,
                sampler=SequenceSampler(train_dataset, train_idx),
                num_workers=workers,
                pin_memory=False
            )

            # ---- training pass ----
            model.train()
            train_loss_phase = 0.0
            train_corrects_phase = 0
            train_start_time = time.time()
            for i, data in enumerate(train_loader):
                optimizer.zero_grad()

                inputs, labels_phase = data[0].to(device), data[1].to(device)

                # Keep only the label of the last frame of every clip.
                labels_phase = labels_phase[(sequence_length - 1)::sequence_length]

                inputs = inputs.view(-1, sequence_length, 3, 216, 216)
                outputs_phase = model(inputs)
                # Likewise keep only the last-frame prediction of every clip.
                outputs_phase = outputs_phase[sequence_length - 1::sequence_length]

                _, preds_phase = torch.max(outputs_phase.data, 1)
                loss_phase = criterion_phase(outputs_phase, labels_phase)

                loss_phase.backward()
                optimizer.step()

                train_loss_phase += loss_phase.item()
                train_corrects_phase += torch.sum(preds_phase == labels_phase).item()

                _print_progress('Train', (i + 1) * train_batch_size, num_train_all)

            train_elapsed_time = time.time() - train_start_time
            train_accuracy_phase = float(train_corrects_phase) / float(num_train_all) * sequence_length
            train_average_loss_phase = train_loss_phase / num_train_all * sequence_length

            writer.add_scalar('train acc epoch phase',
                              float(train_accuracy_phase), epoch)
            writer.add_scalar('train loss epoch phase',
                              float(train_average_loss_phase), epoch)

            # ---- validation pass ----
            model.eval()
            val_loss_phase = 0.0
            val_corrects_phase = 0
            val_start_time = time.time()

            with torch.no_grad():
                for val_progress, data in enumerate(val_loader, start=1):
                    inputs, labels_phase = data[0].to(device), data[1].to(device)

                    labels_phase = labels_phase[(sequence_length - 1)::sequence_length]

                    inputs = inputs.view(-1, sequence_length, 3, 216, 216)
                    outputs_phase = model(inputs)
                    outputs_phase = outputs_phase[sequence_length - 1::sequence_length]

                    _, preds_phase = torch.max(outputs_phase.data, 1)
                    loss_phase = criterion_phase(outputs_phase, labels_phase)

                    val_loss_phase += loss_phase.item()
                    val_corrects_phase += torch.sum(preds_phase == labels_phase).item()

                    _print_progress('Val', val_progress * val_batch_size, num_val_all)

            val_elapsed_time = time.time() - val_start_time
            val_accuracy_phase = float(val_corrects_phase) / float(num_val_we_use)
            val_average_loss_phase = val_loss_phase / num_val_we_use

            writer.add_scalar('validation acc epoch phase',
                              float(val_accuracy_phase), epoch)
            writer.add_scalar('validation loss epoch phase',
                              float(val_average_loss_phase), epoch)

            print('epoch: {:4d}'
                  ' train in: {:2.0f}m{:2.0f}s'
                  ' train loss(phase): {:4.4f}'
                  ' train accu(phase): {:.4f}'
                  ' valid in: {:2.0f}m{:2.0f}s'
                  ' valid loss(phase): {:4.4f}'
                  ' valid accu(phase): {:.4f}'
                  .format(epoch,
                          train_elapsed_time // 60,
                          train_elapsed_time % 60,
                          train_average_loss_phase,
                          train_accuracy_phase,
                          val_elapsed_time // 60,
                          val_elapsed_time % 60,
                          val_average_loss_phase,
                          val_accuracy_phase))

            # Best model: higher validation accuracy wins; on a tie, the
            # epoch with the higher training accuracy wins.
            better_val = val_accuracy_phase > best_val_accuracy_phase
            tie_better_train = (val_accuracy_phase == best_val_accuracy_phase
                                and train_accuracy_phase > correspond_train_acc_phase)
            if better_val or tie_better_train:
                best_val_accuracy_phase = val_accuracy_phase
                correspond_train_acc_phase = train_accuracy_phase
                best_model_wts = copy.deepcopy(model.state_dict())
                best_epoch = epoch

            # Accuracies are encoded in the file name as integers (x 10000).
            save_val_phase = int("{:4.0f}".format(best_val_accuracy_phase * 10000))
            save_train_phase = int("{:4.0f}".format(correspond_train_acc_phase * 10000))
            base_name = "lstm" \
                         + "_epoch_" + str(best_epoch) \
                         + "_length_" + str(sequence_length) \
                         + "_batch_" + str(train_batch_size) \
                         + "_train_" + str(save_train_phase) \
                         + "_val_" + str(save_val_phase)

            torch.save(best_model_wts, best_model_prefix + base_name + ".pth")
            print("best_epoch", str(best_epoch))

            torch.save(model.state_dict(), latest_model_prefix + str(epoch) + ".pth")
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        writer.close()

    return "Complete"

In [30]:
# Training configuration. These module-level names are read as globals by
# the model classes and by train_model().

sequence_length = 10      # frames per clip
train_batch_size = 100    # frames per training batch
val_batch_size = 100      # frames per validation batch
epochs = 1                # number of passes over the training clips
workers = 2               # DataLoader worker processes
learning_rate = 5e-5      # Adam learning rate
MODEL = M1_resnet_lstm()  # alternatives: M2_densenet_lstm() / M3_resnest_lstm()

# Pre-built alternatives:
# M1 = M1_resnet_lstm()
# M2 = M2_densenet_lstm()
# M3 = M3_resnest_lstm()


# Run-specific settings.
# NOTE(review): this overrides the device chosen by the CUDA-availability
# check earlier in the notebook, regardless of whether CUDA is available.
device = "cuda:1"  
model_save_path = 'resnet/sl10_lr5e-5_6-25train/'  # sub-directory appended to the checkpoint roots
tensorboard_path = 'runs/' + model_save_path       # TensorBoard log directory

In [38]:
# Reload the full train/validation split. The loader defined in this
# notebook is get_dataset(); get_data() does not exist and raised NameError.
train_dataset_50, train_num_each_50, \
val_dataset_50, val_num_each_50 = get_dataset('../../gen_datasets/train_val_paths_labels.pkl')
    

train_paths_20  :  14160
train_labels_20 :  14160
valid_paths_5  :   2323
valid_labels_5 :   2323


In [39]:
# Configuration for the M1 (ResNet-50 + LSTM) run; read as globals by train_model().
sequence_length = 10    # frames per clip
train_batch_size = 100  # frames per training batch
val_batch_size = 100    # frames per validation batch
epochs = 1              # number of passes over the training clips
workers = 2             # DataLoader worker processes
learning_rate = 5e-5    # Adam learning rate
MODEL = M1_resnet_lstm()  # alternatives: M2_densenet_lstm() / M3_resnest_lstm()



# Run-specific device and output locations.
device = "cuda:1"  
model_save_path = 'resnet/sl10_lr5e-5_6-25train/'  # sub-directory appended to the checkpoint roots
tensorboard_path = 'runs/' + model_save_path       # TensorBoard log directory


# Train M1 on the full split.
train_model(MODEL,(train_dataset_50),(train_num_each_50),(val_dataset_50),(val_num_each_50))

num train start idx :  13980
num of all train use: 139800
num of all valid use:  22780
Train progress: 100.0% [139800/139800]
Val progress: 100.0% [22780/22780]
epoch:    0 train in: 13m35s train loss(phase): 0.0661 train accu(phase): 0.6985 valid in:  0m40s valid loss(phase): 0.0883 valid accu(phase): 0.6475
best_epoch 0


In [42]:
# Configuration for the M2 (DenseNet-169 + LSTM) run; read as globals by train_model().
sequence_length = 10    # frames per clip
train_batch_size = 100  # frames per training batch
val_batch_size = 100    # frames per validation batch
epochs = 1              # number of passes over the training clips
workers = 2             # DataLoader worker processes
learning_rate = 5e-5    # Adam learning rate
MODEL = M2_densenet_lstm()  # alternatives: M1_resnet_lstm() / M3_resnest_lstm()



# Run-specific device and output locations.
device = "cuda:2"  
model_save_path = 'densenet/sl10_lr5e-5_6-25train/'  # sub-directory appended to the checkpoint roots
tensorboard_path = 'runs/' + model_save_path         # TensorBoard log directory


# Train M2 on the full split.
train_model(MODEL,(train_dataset_50),(train_num_each_50),(val_dataset_50),(val_num_each_50))

num train start idx :  13980
num of all train use: 139800
num of all valid use:  22780
Train progress: 100.0% [139800/139800]
Val progress: 100.0% [22780/22780]
epoch:    0 train in: 10m33s train loss(phase): 0.0642 train accu(phase): 0.7127 valid in:  0m30s valid loss(phase): 0.0765 valid accu(phase): 0.6932
best_epoch 0


In [43]:
# Configuration for the M3 (ResNeSt-50 + LSTM) run; read as globals by train_model().
sequence_length = 10    # frames per clip
train_batch_size = 100  # frames per training batch
val_batch_size = 100    # frames per validation batch
epochs = 1              # number of passes over the training clips
workers = 2             # DataLoader worker processes
learning_rate = 5e-5    # Adam learning rate
MODEL = M3_resnest_lstm()  # alternatives: M1_resnet_lstm() / M2_densenet_lstm()



# Run-specific device and output locations.
device = "cuda:2"  
# NOTE(review): the directory name says lr1e-5 but learning_rate above is
# 5e-5 -- confirm which one is intended before comparing runs.
model_save_path = 'resnest/sl10_lr1e-5_6-25train/'  # sub-directory appended to the checkpoint roots
tensorboard_path = 'runs/' + model_save_path        # TensorBoard log directory


# Train M3 on the full split.
train_model(MODEL,(train_dataset_50),(train_num_each_50),(val_dataset_50),(val_num_each_50))

num train start idx :  13980
num of all train use: 139800
num of all valid use:  22780
Train progress: 100.0% [139800/139800]
Val progress: 100.0% [22780/22780]
epoch:    0 train in: 10m23s train loss(phase): 0.0627 train accu(phase): 0.7154 valid in:  0m33s valid loss(phase): 0.0849 valid accu(phase): 0.6637
best_epoch 0


In [58]:
# Load a small split (3 training videos, 2 validation videos), e.g. for a
# quick end-to-end check of the pipeline.
train_dataset_5, train_num_each_5, \
val_dataset_5, val_num_each_5 = get_dataset('../../gen_datasets/train_val_paths_labels_3.pkl')

train_paths_20  :   1456
train_labels_20 :   1456
valid_paths_5  :    867
valid_labels_5 :    867


In [65]:
# Configuration for an M1 run on the small split; read as globals by train_model().
sequence_length = 10    # frames per clip
train_batch_size = 100  # frames per training batch
val_batch_size = 100    # frames per validation batch
epochs = 1              # number of passes over the training clips
workers = 2             # DataLoader worker processes
learning_rate = 5e-5    # Adam learning rate
MODEL = M1_resnet_lstm()  # alternatives: M2_densenet_lstm() / M3_resnest_lstm()



# Run-specific device and output locations.
# NOTE(review): same model_save_path as the full-split M1 run, so both runs
# write under the same checkpoint and TensorBoard directories.
device = "cuda:1"  
model_save_path = 'resnet/sl10_lr5e-5_6-25train/'  # sub-directory appended to the checkpoint roots
tensorboard_path = 'runs/' + model_save_path       # TensorBoard log directory


# Train M1 on the small split.
train_model(MODEL,(train_dataset_5),(train_num_each_5),(val_dataset_5),(val_num_each_5))

num train start idx :   1429
num of all train use:  14290
num of all valid use:   8490
Train progress: 100.0% [14290/14290]
Val progress: 100.0% [8490/8490]
epoch:    0 train in:  1m 1s train loss(phase): 0.1216 train accu(phase): 0.5430 valid in:  0m11s valid loss(phase): 0.0853 valid accu(phase): 0.6737
best_epoch 0


'Complete'