In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import pandas as pd
from PIL import Image
import time
from tqdm import tqdm

In [2]:
from MobileNetV2 import MobileNetV2

In [3]:
print('Is cuda enabled on this pc?  {}'.format(torch.cuda.is_available())) 

Is cuda enabled on this pc?  True


In [4]:
net = MobileNetV2(n_class=51)
if torch.cuda.is_available():
    net = net.cuda()
    loaded_dict = torch.load('mobilenet_v2.pth.tar') # add map_location='cpu' if no gpu
else:
    loaded_dict = torch.load('mobilenet_v2.pth.tar',map_location='cpu')
state_dict = {k: v for k, v in loaded_dict.items() if k in net.state_dict()}
state_dict["classifier.1.weight"] = net.state_dict()["classifier.1.weight"]
state_dict["classifier.1.bias"] = net.state_dict()["classifier.1.bias"]
net.load_state_dict(state_dict)

In [5]:
class BatchData(Dataset):

    def format_images(self, path , datatype, batch_index):
        path_prefix = '{}/{}/batch{}/'.format(path,datatype,batch_index)
        table = pd.read_csv(path_prefix + 'label.csv', index_col = 0)
        data_list  =[path_prefix+filename for filename in  table['file name'].tolist()]
        label_list = table['label'].tolist()
        
        return data_list,label_list

    def __init__(self, path, datatype, batch_index, transforms):
        self.transforms = transforms
        self.data_list, self.label_list = self.format_images( path,datatype, batch_index)

        #print a summary
        print('Load {} batch {} have {} images '.format(datatype,batch_index,len(self.data_list)))

    def __getitem__(self, idx):
        img = self.data_list[idx]
        img = Image.open(img)
        label = int(self.label_list[idx])
        img = self.transforms(img)
        return img, label,self.data_list[idx].split('/')[-1].split('.')[0]
    
    def __len__(self):
        return len(self.data_list)


In [6]:
trans = transforms.Compose([
#                     transforms.Resize((300,300)),
                    transforms.RandomSizedCrop(255),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                    ])

train_batch_list = [BatchData('./dataset/','train', i, trans) for i in range(1,10)]
valid_batch_list = [BatchData('./dataset/','validation', i, trans) for i in range(1,10)]
test_batch_list  = [BatchData('./dataset/','test', i, trans) for i in range(1,10)]


train_loader_list = [torch.utils.data.DataLoader(batch, batch_size=16,shuffle=True, num_workers=2)
                         for batch in train_batch_list]
valid_loader_list = [torch.utils.data.DataLoader(batch, batch_size=16,shuffle=False, num_workers=2)
                         for batch in valid_batch_list]
test_loader_list = [torch.utils.data.DataLoader(batch, batch_size=16,shuffle=False, num_workers=2)
                         for batch in test_batch_list]

Load train batch 1 have 10550 images 
Load train batch 2 have 10550 images 
Load train batch 3 have 10550 images 
Load train batch 4 have 10550 images 
Load train batch 5 have 10550 images 
Load train batch 6 have 10550 images 
Load train batch 7 have 10553 images 
Load train batch 8 have 10550 images 
Load train batch 9 have 10552 images 
Load validation batch 1 have 3517 images 
Load validation batch 2 have 3517 images 
Load validation batch 3 have 3517 images 
Load validation batch 4 have 3517 images 
Load validation batch 5 have 3517 images 
Load validation batch 6 have 3517 images 
Load validation batch 7 have 3517 images 
Load validation batch 8 have 3517 images 
Load validation batch 9 have 3518 images 
Load test batch 1 have 3517 images 
Load test batch 2 have 3517 images 
Load test batch 3 have 3517 images 
Load test batch 4 have 3517 images 
Load test batch 5 have 3517 images 
Load test batch 6 have 3517 images 
Load test batch 7 have 3518 images 
Load test batch 8 have 3517 

  "please use transforms.RandomResizedCrop instead.")


In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)

In [8]:
trainloader = train_loader_list[0]

In [9]:
def feed(dataloader,is_training,num_epochs = 50, validloader =None, is_valid = False, is_save_csv = False, csv_name = None):
    
    losses = list()
    acces = list()
    filename = list()
    label_gt = list()
    label_predict = list()
    
    since = time.time()
    start = time.time()
    if is_training:
        net.train()
    else:
        num_epochs = 1
        net.eval()

    for epoch in range(num_epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        train_acc = 0.0
        i=0
        for data in dataloader:
            i+=1
            # get the inputs
            inputs, labels, file = data
            inputs, labels = inputs.to(device), labels.to(device)
            #print(inputs)
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)

            loss = criterion(outputs, labels)
            
            if is_training:
                loss.backward()
                optimizer.step()

            # print statistics
            running_loss += loss.item()

            _, pred = outputs.max(1)

            num_correct = (pred == labels).sum().item()
            acc = num_correct / inputs.shape[0]
            train_acc += acc
            #For csv output
            label_predict.extend(pred.tolist())
            label_gt.extend(labels.tolist())
            filename.extend(list(file))
        
        acces.append(train_acc / len(dataloader))
        
        time_elapsed = time.time() - since
        since = time.time()

        if is_training:
            valid_acc, valid_loss = valid(validloader)
            print('epoch{}/{} time:{:.0f}m {:.0f}s train_acc:{:.3f} train_loss:{:.4f} valid_acc:{:.3f} valid loss:{:.3f}'.format(
                epoch+1,
                num_epochs,time_elapsed // 60,
                time_elapsed % 60,
                train_acc / len(dataloader),
                running_loss/len(dataloader),
                valid_acc,
                valid_loss))
        elif not is_valid:
            print('epoch{}/{} time:{:.0f}m {:.0f}s acc:{:.3f} loss:{:.4f}'.format(epoch+1,
                                                                      num_epochs,time_elapsed // 60,
                                                                      time_elapsed % 60,
                                                                      train_acc / len(dataloader),
                                                                      running_loss/len(dataloader)))
            
        if is_save_csv and csv_name != None:
            column = {'file': filename, 
            'label_gt': label_gt, 
            'label_predict': label_predict}  
            log = pd.DataFrame(column) 
            print('write test csv to {}'.format(csv_name))
            log.to_csv(csv_name)
    if is_training:
        time_elapsed = time.time() - start
        print('Complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        
    return train_acc / len(dataloader), running_loss/len(dataloader)

In [10]:
def valid(validloader):
    return feed(validloader,is_training = False, is_valid = True)
def test(testloader,is_save_csv = False,csv_name = None):
    feed(testloader,is_training = False, is_save_csv = is_save_csv, csv_name = csv_name)
def train(trainloader,validloader):
    feed(trainloader,validloader = validloader,is_training = True, num_epochs = 150)
def finetune(trainloader,validloader):
    feed(trainloader,validloader = validloader,is_training = True, num_epochs = 30)

In [11]:
task = 0 
output_path = './output'
trainloader = train_loader_list[task]
validloader = valid_loader_list[task]
print('>> Training in task{}'.format(task+1))
train(trainloader,validloader)
save_path = 'train.pth'
print('>> Save initial trained model to: {}'.format(save_path))
torch.save(net.state_dict(), save_path)
for test_task in range(9):
    print('[Test in task{}]:'.format(test_task+1))
    test(test_loader_list[test_task], is_save_csv = True, csv_name = '{}/test_batch{}.csv'.format(output_path,test_task+1))

# state_dict = torch.load('./model/finetuned_2.pth')
# net.load_state_dict(state_dict)


for tune_task in range(2,9):
    print('>> Finetune in task{}'.format(tune_task+1))
    finetune(train_loader_list[tune_task],valid_loader_list[tune_task])
    save_path = 'finetuned_{}.pth'.format(tune_task+1)
    print('>> Save finetuned model to: {}'.format(save_path))
    torch.save(net.state_dict(), save_path)
    for test_task in range(9):
        print('[Test in task{}]:'.format(test_task+1))
        test(test_loader_list[test_task])

>> Finetune in task3
epoch1/30 time:0m 35s train_acc:0.868 train_loss:0.4877 valid_acc:0.884 valid loss:0.332
epoch2/30 time:0m 39s train_acc:0.886 train_loss:0.3315 valid_acc:0.898 valid loss:0.294
epoch3/30 time:0m 39s train_acc:0.897 train_loss:0.2950 valid_acc:0.896 valid loss:0.301
epoch4/30 time:0m 39s train_acc:0.906 train_loss:0.2719 valid_acc:0.893 valid loss:0.301
epoch5/30 time:0m 39s train_acc:0.901 train_loss:0.2849 valid_acc:0.902 valid loss:0.280
epoch6/30 time:0m 39s train_acc:0.906 train_loss:0.2752 valid_acc:0.920 valid loss:0.228
epoch7/30 time:0m 39s train_acc:0.907 train_loss:0.2728 valid_acc:0.908 valid loss:0.262
epoch8/30 time:0m 39s train_acc:0.902 train_loss:0.2717 valid_acc:0.909 valid loss:0.257
epoch9/30 time:0m 39s train_acc:0.911 train_loss:0.2527 valid_acc:0.908 valid loss:0.255
epoch10/30 time:0m 40s train_acc:0.914 train_loss:0.2490 valid_acc:0.907 valid loss:0.265
epoch11/30 time:0m 39s train_acc:0.913 train_loss:0.2516 valid_acc:0.917 valid loss:0.24

epoch19/30 time:0m 40s train_acc:0.914 train_loss:0.2364 valid_acc:0.928 valid loss:0.214
epoch20/30 time:0m 40s train_acc:0.919 train_loss:0.2259 valid_acc:0.911 valid loss:0.240
epoch21/30 time:0m 40s train_acc:0.928 train_loss:0.2036 valid_acc:0.922 valid loss:0.225
epoch22/30 time:0m 40s train_acc:0.919 train_loss:0.2287 valid_acc:0.914 valid loss:0.241
epoch23/30 time:0m 40s train_acc:0.928 train_loss:0.2082 valid_acc:0.932 valid loss:0.195
epoch24/30 time:0m 40s train_acc:0.922 train_loss:0.2179 valid_acc:0.927 valid loss:0.214
epoch25/30 time:0m 40s train_acc:0.924 train_loss:0.2165 valid_acc:0.910 valid loss:0.251
epoch26/30 time:0m 40s train_acc:0.927 train_loss:0.2109 valid_acc:0.930 valid loss:0.195
epoch27/30 time:0m 40s train_acc:0.927 train_loss:0.2125 valid_acc:0.926 valid loss:0.214
epoch28/30 time:0m 40s train_acc:0.925 train_loss:0.2046 valid_acc:0.927 valid loss:0.211
epoch29/30 time:0m 40s train_acc:0.923 train_loss:0.2181 valid_acc:0.897 valid loss:0.311
epoch30/30

epoch1/1 time:0m 8s acc:0.882 loss:0.3705
>> Finetune in task8
epoch1/30 time:1m 25s train_acc:0.857 train_loss:0.5942 valid_acc:0.906 valid loss:0.313
epoch2/30 time:0m 58s train_acc:0.901 train_loss:0.3015 valid_acc:0.899 valid loss:0.279
epoch3/30 time:0m 39s train_acc:0.906 train_loss:0.2742 valid_acc:0.910 valid loss:0.254
epoch4/30 time:0m 40s train_acc:0.904 train_loss:0.2777 valid_acc:0.916 valid loss:0.236
epoch5/30 time:0m 40s train_acc:0.912 train_loss:0.2607 valid_acc:0.898 valid loss:0.278
epoch6/30 time:0m 40s train_acc:0.913 train_loss:0.2470 valid_acc:0.912 valid loss:0.250
epoch7/30 time:0m 40s train_acc:0.912 train_loss:0.2496 valid_acc:0.918 valid loss:0.218
epoch8/30 time:0m 40s train_acc:0.916 train_loss:0.2397 valid_acc:0.911 valid loss:0.250
epoch9/30 time:0m 40s train_acc:0.921 train_loss:0.2260 valid_acc:0.918 valid loss:0.233
epoch10/30 time:0m 40s train_acc:0.923 train_loss:0.2202 valid_acc:0.913 valid loss:0.246
epoch11/30 time:0m 40s train_acc:0.924 train_l