In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Sampler
from PIL import Image, ImageOps
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from resnest.torch import resnest50
from NLBlock_TVL import TVL,NLBlock

import time
import pickle
import numpy as np
from torchvision.transforms import Lambda
import argparse
import copy
import random
import numbers
from torch.utils.tensorboard import SummaryWriter
from sklearn import metrics
import os

In [2]:
if torch.cuda.is_available()==True:
    device="cuda:2"
else:
    device ="cpu"
    
print(torch.cuda.is_available())

True


# 1. Models:M4, M5, M6

In [2]:
# Model 4 CataRCNet: resnet + lstm + non-local 
class M4_resnet_lstm_nl(torch.nn.Module):
    def __init__(self):
        super(M4_resnet_lstm_nl, self).__init__()
        resnet = models.resnet50(pretrained=True)
        self.share = torch.nn.Sequential()
        self.share.add_module("conv1", resnet.conv1)
        self.share.add_module("bn1", resnet.bn1)
        self.share.add_module("relu", resnet.relu)
        self.share.add_module("maxpool", resnet.maxpool)
        self.share.add_module("layer1", resnet.layer1)
        self.share.add_module("layer2", resnet.layer2)
        self.share.add_module("layer3", resnet.layer3)
        self.share.add_module("layer4", resnet.layer4)
        self.share.add_module("avgpool", resnet.avgpool)
        self.lstm = nn.LSTM(2048, 512, batch_first=True)
        self.fc_c = nn.Linear(512, 19) #7
        self.fc_h_c = nn.Linear(1024, 512)
        self.nl_block = NLBlock()
        self.dropout = nn.Dropout(p=0.5)

        init.xavier_normal_(self.lstm.all_weights[0][0])
        init.xavier_normal_(self.lstm.all_weights[0][1])
        init.xavier_uniform_(self.fc_c.weight)
        init.xavier_uniform_(self.fc_h_c.weight)

    def forward(self, x, long_feature=None):
        x = x.view(-1, 3, 216,216)
        x = self.share.forward(x)
        x = x.view(-1, sequence_length, 2048)
        self.lstm.flatten_parameters()
        y, _ = self.lstm(x)
        y = y.contiguous().view(-1, 512)
        y = y[sequence_length - 1::sequence_length]

        y_1 = self.nl_block(y, long_feature)
        y = torch.cat([y, y_1], dim=1)
        y = self.dropout(self.fc_h_c(y))
        y = F.relu(y)
        y = self.fc_c(y)
        return y

In [3]:
# M5: densenet + lstm + non-local
class M5_densenet_lstm_nl(torch.nn.Module):
    def __init__(self):
        super(M5_densenet_lstm_nl, self).__init__()
        densenet = models.densenet169(pretrained=True) #pretrained=True
        self.share = torch.nn.Sequential()
        self.share.add_module("features", densenet.features)
        #self.share.add_module("avgpool", resnet.avgpool)
        self.avg = nn.AvgPool2d(6)

        # self.share.add_module("classifier", resnet.classifier)
        
        #self.fc_1 = nn.Linear(9216, 4096)
         
        self.lstm = nn.LSTM(1664, 512, batch_first=True)
#        self.lstm = nn.LSTM(2028, 512, batch_first=True)
        #self.fc = nn.Linear(512, 19)
        self.fc_c = nn.Linear(512, 19) #7
        self.fc_h_c = nn.Linear(1024, 512)
        self.nl_block = NLBlock()
        self.dropout = nn.Dropout(p=0.5)

        
        init.xavier_normal_(self.lstm.all_weights[0][0])
        init.xavier_normal_(self.lstm.all_weights[0][1])
        init.xavier_uniform_(self.fc_c.weight)
        init.xavier_uniform_(self.fc_h_c.weight)

    def forward(self, x,long_feature):
        x = x.view(-1, 3, 216, 216)
        x = self.share.forward(x) # ([100, 1664, 6, 6])   # ([100,2048,1,1])
        x = self.avg(x)
        x = x.view(-1, sequence_length, 1664)  
        self.lstm.flatten_parameters()
        y, _ = self.lstm(x)
        y = y.contiguous().view(-1, 512)
        y = y[sequence_length - 1::sequence_length]
        
        
        y_1 = self.nl_block(y, long_feature)
        y = torch.cat([y, y_1], dim=1)
        y = self.dropout(self.fc_h_c(y))
        y = F.relu(y)
        y = self.fc_c(y)
        return y

In [4]:
# Model 6  resnest + lstm + non-local 
class M6_resnest_lstm_nl(torch.nn.Module):
    def __init__(self):
        super(M6_resnest_lstm_nl, self).__init__()
        resnet = resnest50(pretrained=True)
        self.share = torch.nn.Sequential()
        self.share.add_module("conv1", resnet.conv1)
        self.share.add_module("bn1", resnet.bn1)
        self.share.add_module("relu", resnet.relu)
        self.share.add_module("maxpool", resnet.maxpool)
        self.share.add_module("layer1", resnet.layer1)
        self.share.add_module("layer2", resnet.layer2)
        self.share.add_module("layer3", resnet.layer3)
        self.share.add_module("layer4", resnet.layer4)
        self.share.add_module("avgpool", resnet.avgpool)
        self.lstm = nn.LSTM(2048, 512, batch_first=True)
        self.fc_c = nn.Linear(512, 19) #7
        self.fc_h_c = nn.Linear(1024, 512)
        self.nl_block = NLBlock()
        self.dropout = nn.Dropout(p=0.5)

        init.xavier_normal_(self.lstm.all_weights[0][0])
        init.xavier_normal_(self.lstm.all_weights[0][1])
        init.xavier_uniform_(self.fc_c.weight)
        init.xavier_uniform_(self.fc_h_c.weight)

    def forward(self, x, long_feature=None):
        x = x.view(-1, 3, 216,216)
        x = self.share.forward(x)
        x = x.view(-1, sequence_length, 2048)
        self.lstm.flatten_parameters()
        y, _ = self.lstm(x)
        y = y.contiguous().view(-1, 512)
        y = y[sequence_length - 1::sequence_length]

        y_1 = self.nl_block(y, long_feature)
        y = torch.cat([y, y_1], dim=1)
        y = self.dropout(self.fc_h_c(y))
        y = F.relu(y)
        y = self.fc_c(y)
        return y

In [16]:
M4 = M4_resnet_lstm_nl()

In [15]:
M5 = M5_densenet_lstm_nl()

In [18]:
M6 = M6_resnest_lstm_nl()

# 2. LFB Models: 

In [5]:
#long feature bank bank

# resnet+lstm lfb  used for M4, M7  
class LFB_resnet_lstm(torch.nn.Module):
    def __init__(self):
        super(LFB_resnet_lstm, self).__init__()
        resnet = models.resnet50(pretrained=True)
        self.share = torch.nn.Sequential()
        self.share.add_module("conv1", resnet.conv1)
        self.share.add_module("bn1", resnet.bn1)
        self.share.add_module("relu", resnet.relu)
        self.share.add_module("maxpool", resnet.maxpool)
        self.share.add_module("layer1", resnet.layer1)
        self.share.add_module("layer2", resnet.layer2)
        self.share.add_module("layer3", resnet.layer3)
        self.share.add_module("layer4", resnet.layer4)
        self.share.add_module("avgpool", resnet.avgpool)
        self.lstm = nn.LSTM(2048, 512, batch_first=True)

        init.xavier_normal_(self.lstm.all_weights[0][0])
        init.xavier_normal_(self.lstm.all_weights[0][1])

    def forward(self, x):
        x = x.view(-1, 3, 216,216)
        x = self.share.forward(x)
        x = x.view(-1, sequence_length, 2048)
        self.lstm.flatten_parameters()
        y, _ = self.lstm(x)
        y = y.contiguous().view(-1, 512)
        y = y[sequence_length - 1::sequence_length]
        return y

In [6]:
# long feature bank used for M5, M8
class LFB_densenet_lstm(torch.nn.Module):
    def __init__(self):
        super(LFB_densenet_lstm, self).__init__()
        densenet = models.densenet169(pretrained=True) #pretrained=True
        self.share = torch.nn.Sequential()
        self.share.add_module("features", densenet.features)
        #self.share.add_module("avgpool", resnet.avgpool)
        self.avg = nn.AvgPool2d(6)

        # self.share.add_module("classifier", resnet.classifier)      
        #self.fc_1 = nn.Linear(9216, 4096)
        
        # 
        self.lstm = nn.LSTM(1664, 512, batch_first=True)
#        self.lstm = nn.LSTM(2028, 512, batch_first=True)
        #self.fc = nn.Linear(512, 19)

        

        init.xavier_normal_(self.lstm.all_weights[0][0])
        init.xavier_normal_(self.lstm.all_weights[0][1])
        #init.xavier_uniform_(self.fc.weight)

    def forward(self, x):
        x = x.view(-1, 3, 216, 216)
        x = self.share.forward(x) # ([100, 1664, 6, 6])   # ([100,2048,1,1])
        x = self.avg(x)
        x = x.view(-1, sequence_length, 1664)  
        self.lstm.flatten_parameters()
        y, _ = self.lstm(x)
        y = y.contiguous().view(-1, 512)
        y = y[sequence_length - 1::sequence_length]
        return y

In [7]:
#long feature bank bank

# resnet+lstm lfb  used for M6,M9
class LFB_resnest_lstm(torch.nn.Module):
    def __init__(self):
        super(LFB_resnest_lstm, self).__init__()
        resnest = resnest50(pretrained=True)
        self.share = torch.nn.Sequential()
        self.share.add_module("conv1", resnest.conv1)
        self.share.add_module("bn1", resnest.bn1)
        self.share.add_module("relu", resnest.relu)
        self.share.add_module("maxpool", resnest.maxpool)
        self.share.add_module("layer1", resnest.layer1)
        self.share.add_module("layer2", resnest.layer2)
        self.share.add_module("layer3", resnest.layer3)
        self.share.add_module("layer4", resnest.layer4)
        self.share.add_module("avgpool", resnest.avgpool)
        self.lstm = nn.LSTM(2048, 512, batch_first=True)

        init.xavier_normal_(self.lstm.all_weights[0][0])
        init.xavier_normal_(self.lstm.all_weights[0][1])

    def forward(self, x):
        x = x.view(-1, 3, 216,216)
        x = self.share.forward(x)
        x = x.view(-1, sequence_length, 2048)
        self.lstm.flatten_parameters()
        y, _ = self.lstm(x)
        y = y.contiguous().view(-1, 512)
        y = y[sequence_length - 1::sequence_length]
        return y

# 3. Loss function

In [8]:
class WeightedCrossEntropy(torch.nn.Module):
    '''
    WCE
    '''       
     # 6-25
    def __init__(self, weight=torch.Tensor([0.0033, 0.4182, 0.1321, 0.0234, 0.0344, 0.0146, 0.0428, 0.0140, 0.0092,
        0.0272, 0.0096, 0.0323, 0.0341, 0.0508, 0.0151, 0.0160, 0.0365, 0.0738,
        0.0128])):
        super(WeightedCrossEntropy, self).__init__()
        
        weight = weight.to(device)
        self.weighted_cross_entropy = nn.CrossEntropyLoss(weight=weight)
        
    def forward(self, inputs, target):
        return self.weighted_cross_entropy.forward(inputs, target)
    

# 4. Dataset

In [9]:
def pil_loader(path):
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')

class CataractsDataset(Dataset):
    def __init__(self, file_paths,file_labels, transform=None,loader=pil_loader):
        self.file_paths = file_paths
        self.file_labels_phase = file_labels[:,0]
        self.transform = transform
        self.loader = loader
        
    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, index):
        img_names = self.file_paths[index]
        labels = self.file_labels_phase[index]
        imgs = self.loader(img_names)
        if self.transform is not None:
            imgs = self.transform(imgs)

        return imgs, labels, index

In [10]:
def get_dataset(data_path):
    with open(data_path, 'rb') as f:
        train_test_paths_labels = pickle.load(f)
    train_paths_50 = train_test_paths_labels[0]
    val_paths_50 = train_test_paths_labels[1]
    train_labels_50 = train_test_paths_labels[2]
    val_labels_50 = train_test_paths_labels[3]
    train_num_each_50 = train_test_paths_labels[4]
    val_num_each_50 = train_test_paths_labels[5]

    print('train_paths_20  : {:6d}'.format(len(train_paths_50)))
    print('train_labels_20 : {:6d}'.format(len(train_labels_50)))
    print('valid_paths_5  : {:6d}'.format(len(val_paths_50)))
    print('valid_labels_5 : {:6d}'.format(len(val_labels_50)))

    # train_labels_19 = np.asarray(train_labels_19, dtype=np.int64) yilin comment
    train_labels_50 = np.asarray(train_labels_50, dtype=np.int64)
    val_labels_50 = np.asarray(val_labels_50, dtype=np.int64)
    
    train_transforms = transforms.Compose([
            transforms.CenterCrop(216),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(5),
            transforms.ToTensor()
        ])
    
    test_transforms = transforms.Compose([
            transforms.CenterCrop(216),
            transforms.ToTensor(),

        ])

    train_dataset_50 = CataractsDataset(train_paths_50, train_labels_50, train_transforms)
    val_dataset_50 = CataractsDataset(val_paths_50, val_labels_50, test_transforms)
    train_dataset_50_LFB =CataractsDataset(train_paths_50, train_labels_50, test_transforms)
    
    return (train_dataset_50,train_dataset_50_LFB), train_num_each_50,val_dataset_50, val_num_each_50

# 4. Load long term feature bank

In [32]:
# Long Term Feature bank
g_LFB_train = np.zeros(shape=(0, 512))
g_LFB_val = np.zeros(shape=(0, 512))

In [11]:
# for sliding window
def get_start_idx(sequence_length, list_each_length):
    count = 0
    idx = []
    for i in range(len(list_each_length)):
        for j in range(count, count + (list_each_length[i] + 1 - sequence_length)):
            idx.append(j)
        count += list_each_length[i]
    return idx

In [12]:
def get_long_feature(start_index_list, dict_start_idx_LFB, lfb):
    long_feature = []
    for j in range(len(start_index_list)):
        long_feature_each = []
        
        # 上一个存在feature的index
        last_LFB_index_no_empty = dict_start_idx_LFB[int(start_index_list[j])]
        
        # 
        for k in range(LFB_length):
            LFB_index = (start_index_list[j] - k - 1)
            if int(LFB_index) in dict_start_idx_LFB:                
                LFB_index = dict_start_idx_LFB[int(LFB_index)]
                long_feature_each.append(lfb[LFB_index])
                last_LFB_index_no_empty = LFB_index
            else:
                long_feature_each.append(lfb[last_LFB_index_no_empty])
            
        long_feature.append(long_feature_each)
    return long_feature

In [13]:
class SequenceSampler(Sampler):
    def __init__(self, data_source, idx):
        super().__init__(data_source)
        self.data_source = data_source
        self.idx = idx

    def __iter__(self):
        return iter(self.idx)

    def __len__(self):
        return len(self.idx)


# 5. Training M4, M5, M6

In [18]:
# Long Term Feature bank
g_LFB_train = np.zeros(shape=(0, 512))
g_LFB_val = np.zeros(shape=(0, 512))

def train_model(model,train_dataset_2, train_num_each, val_dataset, val_num_each):
    # TensorBoard
    writer = SummaryWriter(tensorboard_path)
    
    (train_dataset,train_dataset_LFB) = train_dataset_2
    
    # choose start index for sequence 
    train_useful_start_idx = get_start_idx(sequence_length, train_num_each)
    val_useful_start_idx = get_start_idx(sequence_length, val_num_each)

    num_train_we_use = len(train_useful_start_idx)
    num_val_we_use = len(val_useful_start_idx)
    
    # choose start index for feature bank
 
    
    train_useful_start_idx_LFB = train_useful_start_idx
    val_useful_start_idx_LFB = val_useful_start_idx

    num_train_we_use_LFB = num_train_we_use 
    num_val_we_use_LFB = num_val_we_use
    
    
    train_idx = []
    for i in range(num_train_we_use):
        for j in range(sequence_length):
            train_idx.append(train_useful_start_idx[i] + j)

    val_idx = []
    for i in range(num_val_we_use):
        for j in range(sequence_length):
            val_idx.append(val_useful_start_idx[i] + j)

    num_train_all = len(train_idx)
    num_val_all = len(val_idx)
    
    
    train_idx_LFB = train_idx
    val_idx_LFB = val_idx
    
    dict_index, dict_value = zip(*list(enumerate(train_useful_start_idx_LFB)))
    dict_train_start_idx_LFB = dict(zip(dict_value, dict_index))

    dict_index, dict_value = zip(*list(enumerate(val_useful_start_idx_LFB)))
    dict_val_start_idx_LFB = dict(zip(dict_value, dict_index))

    print('num train start idx : {:6d}'.format(len(train_useful_start_idx)))
    print('num of all train use: {:6d}'.format(num_train_all))
    print('num of all valid use: {:6d}'.format(num_val_all))
    print('num of all train LFB use: {:6d}'.format(len(train_idx_LFB)))
    print('num of all valid LFB use: {:6d}'.format(len(val_idx_LFB)))

    val_loader = DataLoader(
        val_dataset,
        batch_size=val_batch_size,
        sampler=SequenceSampler(val_dataset, val_idx),
        num_workers=workers,
        pin_memory=False
    )

    
    # load long feature bank
    global g_LFB_train
    global g_LFB_val
    print("loading features!->.........")
    
    if not load_exist_LFB:
        g_LFB_train = np.zeros(shape=(0, 512))
        g_LFB_val = np.zeros(shape=(0, 512))
        
        train_feature_loader = DataLoader(
            train_dataset_LFB,
            batch_size=val_batch_size,
            sampler=SequenceSampler(train_dataset_LFB, train_idx_LFB),
            num_workers=workers,
            pin_memory=False
        )
        val_feature_loader = DataLoader(
            val_dataset,
            batch_size=val_batch_size,
            sampler=SequenceSampler(val_dataset, val_idx_LFB),
            num_workers=workers,
            pin_memory=False
        )
        
        model_LFB.load_state_dict(torch.load(LFB_path), strict=False)
        
        model_LFB.to(device)
        model_LFB.eval()
        
        with torch.no_grad():

            for data in train_feature_loader:
                inputs, labels_phase = data[0].to(device), data[1].to(device)


                inputs = inputs.view(-1, sequence_length, 3, 216,216)
                outputs_feature = model_LFB.forward(inputs)

                for j in range(len(outputs_feature)):
                    save_feature = outputs_feature.data.cpu()[j].numpy()
                    save_feature = save_feature.reshape(1, 512)
                    g_LFB_train = np.concatenate((g_LFB_train, save_feature),axis=0)

                print("train feature length:",len(g_LFB_train))

            for data in val_feature_loader:
                
                inputs, labels_phase = data[0].to(device), data[1].to(device)


                inputs = inputs.view(-1, sequence_length, 3, 216,216)
                outputs_feature = model_LFB.forward(inputs)

                for j in range(len(outputs_feature)):
                    save_feature = outputs_feature.data.cpu()[j].numpy()
                    save_feature = save_feature.reshape(1, 512)
                    g_LFB_val = np.concatenate((g_LFB_val, save_feature), axis=0)

                print("val feature length:",len(g_LFB_val))

        print("finish!")
        g_LFB_train = np.array(g_LFB_train)
        g_LFB_val = np.array(g_LFB_val)

        # LFB_train_path_save_path = "./LFB/g_LFB_train_densenet.pkl"
        with open(LFB_train_path_save_path, 'wb') as f:
            pickle.dump(g_LFB_train, f)

        with open(LFB_val_path_save_path, 'wb') as f:
            pickle.dump(g_LFB_val, f)
    
    else:
        with open(LFB_train_path_save_path, 'rb') as f:
            g_LFB_train = pickle.load(f)

        with open(LFB_val_path_save_path, 'rb') as f:
            g_LFB_val = pickle.load(f)

        print("load completed")
        
        
    print("g_LFB_train shape:",g_LFB_train.shape)
    print("g_LFB_val shape:",g_LFB_val.shape)
        
    
    torch.cuda.empty_cache()

    #####################################
    #model = resnet_lstm()
    # model.load_state_dict(torch.load(pretrained_model_path))
    #####################################   
    
    model.load_state_dict(torch.load(LFB_path), strict=False)
    model.to(device)

    criterion_phase = WeightedCrossEntropy() #nn.CrossEntropyLoss(size_average=False)
    
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    
    best_model_wts = copy.deepcopy(model.state_dict()) 
    
    best_val_accuracy_phase = 0.0
    correspond_train_acc_phase = 0.0
    best_epoch = 0

    for epoch in range(epochs):
        torch.cuda.empty_cache()
        
        np.random.shuffle(train_useful_start_idx)
        train_idx = []
        for i in range(num_train_we_use):
            for j in range(sequence_length):
                train_idx.append(train_useful_start_idx[i] + j)

        train_loader = DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            sampler=SequenceSampler(train_dataset, train_idx),
            num_workers=workers,
            pin_memory=False
        )

        # in training mode.
        model.train()
        train_loss_phase = 0.0
        train_corrects_phase = 0
        batch_progress = 0.0
        running_loss_phase = 0.0
        minibatch_correct_phase = 0.0
        train_start_time = time.time()
        for i, data in enumerate(train_loader):
            optimizer.zero_grad()

            inputs, labels_phase = data[0].to(device), data[1].to(device)

            labels_phase = labels_phase[(sequence_length - 1)::sequence_length]

            start_index_list = data[2]
            start_index_list = start_index_list[0::sequence_length]
            long_feature = get_long_feature(start_index_list=start_index_list,
                                            dict_start_idx_LFB=dict_train_start_idx_LFB,
                                            lfb=g_LFB_train)

            long_feature = (torch.Tensor(long_feature)).to(device)
            
            
            
            
            inputs = inputs.view(-1, sequence_length, 3, 216,216) #224, 224)
            outputs_phase = model.forward(inputs,long_feature=long_feature)
            #outputs_phase = outputs_phase[sequence_length - 1::sequence_length]

            _, preds_phase = torch.max(outputs_phase.data, 1)
            loss_phase = criterion_phase(outputs_phase, labels_phase)

            loss = loss_phase
            loss.backward()
            optimizer.step()

            running_loss_phase += loss_phase.data.item()
            train_loss_phase += loss_phase.data.item()

            batch_corrects_phase = torch.sum(preds_phase == labels_phase.data)
            train_corrects_phase += batch_corrects_phase


            if (i+1)*train_batch_size >= num_train_all:               
                running_loss_phase = 0.0
                minibatch_correct_phase = 0.0

            batch_progress += 1
            if batch_progress*train_batch_size >= num_train_all:
                percent = 100.0
                print('Train progress: %s [%d/%d]' % (str(percent) + '%', num_train_all, num_train_all), end='\n')
            else:
                percent = round(batch_progress*train_batch_size / num_train_all * 100, 2)
                print('Train progress: %s [%d/%d]' % (str(percent) + '%', batch_progress*train_batch_size, num_train_all), end='\r')

        train_elapsed_time = time.time() - train_start_time
        train_accuracy_phase = float(train_corrects_phase) / float(num_train_all) * sequence_length
        train_average_loss_phase = train_loss_phase / num_train_all * sequence_length

        
        writer.add_scalar('train acc epoch phase',
                          float(train_accuracy_phase),epoch)
        writer.add_scalar('train loss epoch phase',
                          float(train_average_loss_phase),epoch)
        
        
        
        
        #  in evaluation mode.
        model.eval()
        val_loss_phase = 0.0
        val_corrects_phase = 0
        val_start_time = time.time()
        val_progress = 0
        val_all_preds_phase = []
        val_all_labels_phase = []

        with torch.no_grad():
            for data in val_loader:

                inputs, labels_phase = data[0].to(device), data[1].to(device)


                labels_phase = labels_phase[(sequence_length - 1)::sequence_length]

                
                start_index_list = data[2]
                start_index_list = start_index_list[0::sequence_length]
                long_feature = get_long_feature(start_index_list=start_index_list,
                                                dict_start_idx_LFB=dict_val_start_idx_LFB,
                                                lfb=g_LFB_val)

                long_feature = torch.Tensor(long_feature).to(device)

                inputs = inputs.view(-1, sequence_length, 3, 216,216)
                outputs_phase = model.forward(inputs, long_feature=long_feature)
                
                
                # outputs_phase = outputs_phase[sequence_length - 1::sequence_length]

                _, preds_phase = torch.max(outputs_phase.data, 1)
                loss_phase = criterion_phase(outputs_phase, labels_phase)

                val_loss_phase += loss_phase.data.item()

                val_corrects_phase += torch.sum(preds_phase == labels_phase.data)


                for i in range(len(preds_phase)):
                    val_all_preds_phase.append(int(preds_phase.data.cpu()[i]))
                for i in range(len(labels_phase)):
                    val_all_labels_phase.append(int(labels_phase.data.cpu()[i]))


                val_progress += 1
                if val_progress*val_batch_size >= num_val_all:
                    percent = 100.0
                    print('Val progress: %s [%d/%d]' % (str(percent) + '%', num_val_all, num_val_all), end='\n')
                else:
                    percent = round(val_progress*val_batch_size / num_val_all * 100, 2)
                    print('Val progress: %s [%d/%d]' % (str(percent) + '%', val_progress*val_batch_size, num_val_all), end='\r')

        val_elapsed_time = time.time() - val_start_time
        val_accuracy_phase = float(val_corrects_phase) / float(num_val_we_use)
        val_average_loss_phase = val_loss_phase / num_val_we_use


        writer.add_scalar('validation acc epoch phase',
                          float(val_accuracy_phase),epoch)
        writer.add_scalar('validation loss epoch phase',
                          float(val_average_loss_phase),epoch)

        print('epoch: {:4d}'
              ' train in: {:2.0f}m{:2.0f}s'
              ' train loss(phase): {:4.4f}'
              ' train accu(phase): {:.4f}'
              ' valid in: {:2.0f}m{:2.0f}s'
              ' valid loss(phase): {:4.4f}'
              ' valid accu(phase): {:.4f}'
              .format(epoch,
                      train_elapsed_time // 60,
                      train_elapsed_time % 60,
                      train_average_loss_phase,
                      train_accuracy_phase,
                      val_elapsed_time // 60,
                      val_elapsed_time % 60,
                      val_average_loss_phase,
                      val_accuracy_phase))


        # choose the best model by accuracy
        if val_accuracy_phase > best_val_accuracy_phase:
            best_val_accuracy_phase = val_accuracy_phase
            correspond_train_acc_phase = train_accuracy_phase
            #copy the best model
            best_model_wts = copy.deepcopy(model.state_dict())
            best_epoch = epoch
        if val_accuracy_phase == best_val_accuracy_phase:
            if train_accuracy_phase > correspond_train_acc_phase:
                correspond_train_acc_phase = train_accuracy_phase
                best_model_wts = copy.deepcopy(model.state_dict())
                best_epoch = epoch

        save_val_phase = int("{:4.0f}".format(best_val_accuracy_phase * 10000))
        save_train_phase = int("{:4.0f}".format(correspond_train_acc_phase * 10000))
        base_name = "lstm" \
                     + "_epoch_" + str(best_epoch) \
                     + "_length_" + str(sequence_length) \
                     + "_batch_" + str(train_batch_size) \
                     + "_train_" + str(save_train_phase) \
                     + "_val_" + str(save_val_phase)
        
        #model_save_path = 'sl10_flip1_lr5e-5/'
        
        torch.save(best_model_wts, "/media/yilin/catarcnet/best_model/"+model_save_path+base_name+".pth")
        print("best_epoch",str(best_epoch))
        # model.module.state_dict()
        
        torch.save(model.state_dict(), "/media/yilin/catarcnet/temp/"+model_save_path+ "latest_model_"+str(epoch)+".pth")
              

    return "Complete"

In [None]:
# in this experiment, train06-train25 for training, train01-train05 for validation

train_dataset_20, train_num_each_20, \
val_dataset_5, val_num_each_5 = get_dataset('../../gen_datasets/train_val_paths_labels.pkl')

train_model(MODEL,(train_dataset_20),(train_num_each_20),(val_dataset_20),(val_num_each_20))

In [None]:
# before training, we need make a dir of the best model and temp model

In [16]:
# for show example output of train process
# 3 training videos, 2 val videos
train_dataset_5, train_num_each_5, \
val_dataset_5, val_num_each_5 = get_dataset('../../gen_datasets/train_val_paths_labels_3.pkl')

train_paths_20  :   1456
train_labels_20 :   1456
valid_paths_5  :    867
valid_labels_5 :    867


In [17]:
# the first time to train, it is necessary to load feature bank

sequence_length = 10    # the length of input clip 
train_batch_size = 100  # batch size 
val_batch_size = 100     
epochs = 3 
workers = 2   
learning_rate = 5e-5
LFB_length = 30
load_exist_LFB = True # False

MODEL = M4_resnet_lstm_nl() # M5_densenet_lstm_nl() /  M6_resnest_lstm_nl()

model_LFB = LFB_resnet_lstm()
LFB_train_path_save_path = "./LFB/g_LFB_train_resnet.pkl"
LFB_val_path_save_path = "./LFB/g_LFB_val_resnet.pkl"



#########################################
device = "cuda:2"  

model_save_path = 'non-local/pretrained_lr1e-5_L30_2fc_resnet_nl/' # 'resnet/sl10_lr5e-5_6-25train/' cuda:0
tensorboard_path = 'runs/' + model_save_path

LFB_path = "/media/yilin/catarcnet/temp/resnet/sl10_lr5e-5_6-25train/latest_model_14.pth"
###########################################


# train M4  # example output using train01-03 for training, train04-train05 for validation
train_model(MODEL,(train_dataset_5),(train_num_each_5),(val_dataset_5),(val_num_each_5))



num train start idx :   1429
num of all train use:  14290
num of all valid use:   8490
num of all train LFB use:  14290
num of all valid LFB use:   8490
loading features!->.........
train feature length: 10
train feature length: 20
train feature length: 30
train feature length: 40
train feature length: 50
train feature length: 60
train feature length: 70
train feature length: 80
train feature length: 90
train feature length: 100
train feature length: 110
train feature length: 120
train feature length: 130
train feature length: 140
train feature length: 150
train feature length: 160
train feature length: 170
train feature length: 180
train feature length: 190
train feature length: 200
train feature length: 210
train feature length: 220
train feature length: 230
train feature length: 240
train feature length: 250
train feature length: 260
train feature length: 270
train feature length: 280
train feature length: 290
train feature length: 300
train feature length: 310
train feature length:

'Complete'

In [22]:
# the first time to train, it is necessary to load feature bank

sequence_length = 10    # the length of input clip 
train_batch_size = 100  # batch size 
val_batch_size = 100     
epochs = 3 
workers = 2   
learning_rate = 5e-5
LFB_length = 30
load_exist_LFB = True # False

MODEL = M4_resnet_lstm_nl() # M5_densenet_lstm_nl() /  M6_resnest_lstm_nl()

model_LFB = LFB_resnet_lstm()
LFB_train_path_save_path = "./LFB/g_LFB_train_resnet.pkl"
LFB_val_path_save_path = "./LFB/g_LFB_val_resnet.pkl"



#########################################
device = "cuda:2"  

model_save_path = 'non-local/pretrained_lr5e-5_L30_2fc_resnet_nl/' # 'resnet/sl10_lr5e-5_6-25train/' cuda:0
tensorboard_path = 'runs/' + model_save_path

LFB_path = "/media/yilin/catarcnet/temp/resnet/sl10_lr5e-5_6-25train/latest_model_14.pth"
###########################################


# train M4
train_model(MODEL,(train_dataset_5),(train_num_each_5),(val_dataset_5),(val_num_each_5))

num train start idx :   1429
num of all train use:  14290
num of all valid use:   8490
num of all train LFB use:  14290
num of all valid LFB use:   8490
loading features!->.........
load completed
g_LFB_train shape: (1429, 512)
g_LFB_val shape: (849, 512)
Train progress: 100.0% [14290/14290]
Val progress: 100.0% [8490/8490]
epoch:    0 train in:  0m58s train loss(phase): 0.1670 train accu(phase): 0.4430 valid in:  0m11s valid loss(phase): 0.1155 valid accu(phase): 0.5112
best_epoch 0
Train progress: 100.0% [14290/14290]
Val progress: 100.0% [8490/8490]
epoch:    1 train in:  0m58s train loss(phase): 0.0582 train accu(phase): 0.7124 valid in:  0m11s valid loss(phase): 0.0940 valid accu(phase): 0.6160
best_epoch 1
Train progress: 100.0% [14290/14290]
Val progress: 100.0% [8490/8490]
epoch:    2 train in:  0m58s train loss(phase): 0.0370 train accu(phase): 0.7873 valid in:  0m12s valid loss(phase): 0.1253 valid accu(phase): 0.5524
best_epoch 1


'Complete'

In [25]:
sequence_length = 10    # the length of input clip 
train_batch_size = 100  # batch size 
val_batch_size = 100     
epochs = 3 
workers = 2   
learning_rate = 1e-5
LFB_length = 30
load_exist_LFB = False # False

MODEL = M5_densenet_lstm_nl() # M5_densenet_lstm_nl() /  M6_resnest_lstm_nl()

model_LFB = LFB_densenet_lstm()
LFB_train_path_save_path = "./LFB/g_LFB_train_densenet.pkl"
LFB_val_path_save_path = "./LFB/g_LFB_val_densenet.pkl"



#########################################
device = "cuda:2"  

model_save_path = 'non-local/pretrained_lr1e-5_L30_2fc_densenet_nl/' # 'resnet/sl10_lr5e-5_6-25train/' cuda:0
tensorboard_path = 'runs/' + model_save_path

LFB_path = "/media/yilin/catarcnet/temp/densenet/sl10_lr1e-5_6-25train/latest_model_6.pth"
###########################################


# train M5
train_model(MODEL,(train_dataset_5),(train_num_each_5),(val_dataset_5),(val_num_each_5))


num train start idx :   1429
num of all train use:  14290
num of all valid use:   8490
num of all train LFB use:  14290
num of all valid LFB use:   8490
loading features!->.........
train feature length: 10
train feature length: 20
train feature length: 30
train feature length: 40
train feature length: 50
train feature length: 60
train feature length: 70
train feature length: 80
train feature length: 90
train feature length: 100
train feature length: 110
train feature length: 120
train feature length: 130
train feature length: 140
train feature length: 150
train feature length: 160
train feature length: 170
train feature length: 180
train feature length: 190
train feature length: 200
train feature length: 210
train feature length: 220
train feature length: 230
train feature length: 240
train feature length: 250
train feature length: 260
train feature length: 270
train feature length: 280
train feature length: 290
train feature length: 300
train feature length: 310
train feature length:

'Complete'

In [26]:
sequence_length = 10    # the length of input clip 
train_batch_size = 100  # batch size 
val_batch_size = 100     
epochs = 2 
workers = 2   
learning_rate = 5e-5
LFB_length = 30
load_exist_LFB = False # True

MODEL = M6_resnest_lstm_nl() # M5_densenet_lstm_nl() /  M6_resnest_lstm_nl()

model_LFB = LFB_resnest_lstm()
LFB_train_path_save_path = "./LFB/g_LFB_train_resnest.pkl"
LFB_val_path_save_path = "./LFB/g_LFB_val_resnest.pkl"



#########################################
device = "cuda:2"  

model_save_path = 'non-local/pretrained_lr1e-5_L30_2fc_resnest_nl/' # 'resnet/sl10_lr5e-5_6-25train/' cuda:0
tensorboard_path = 'runs/' + model_save_path

LFB_path = "/media/yilin/catarcnet/temp/resnest/sl10_lr1e-5_6-25train/latest_model_12.pth"
###########################################


# train M6
train_model(MODEL,(train_dataset_5),(train_num_each_5),(val_dataset_5),(val_num_each_5))



num train start idx :   1429
num of all train use:  14290
num of all valid use:   8490
num of all train LFB use:  14290
num of all valid LFB use:   8490
loading features!->.........
train feature length: 10
train feature length: 20
train feature length: 30
train feature length: 40
train feature length: 50
train feature length: 60
train feature length: 70
train feature length: 80
train feature length: 90
train feature length: 100
train feature length: 110
train feature length: 120
train feature length: 130
train feature length: 140
train feature length: 150
train feature length: 160
train feature length: 170
train feature length: 180
train feature length: 190
train feature length: 200
train feature length: 210
train feature length: 220
train feature length: 230
train feature length: 240
train feature length: 250
train feature length: 260
train feature length: 270
train feature length: 280
train feature length: 290
train feature length: 300
train feature length: 310
train feature length:

'Complete'