# Import libraries

In [None]:
import os
from tqdm import tqdm
from tqdm import tqdm_notebook
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
import torchvision

# Custom data loading for PyTorch

In [None]:
class sample(object):
    """One loaded image record.

    Attributes:
        _id: running index assigned by the loader.
        img_path: path of the source image file.
        ori_label: label the sample was created with; never changed afterwards.
        label: current label (may be overwritten, e.g. by ``load_dataset.set_label``).
        img: the (possibly transformed) image.
    """
    def __init__(self, _id, img_path, ori_label, label, img):
        self._id = _id
        self.img_path = img_path
        # Keep the original label so it survives later relabelling of ``label``.
        self.ori_label = ori_label
        self.label = label
        self.img = img

class load_dataset(object):
    """In-memory image-folder dataset.

    Entries of ``file_path`` are listed in sorted order, and each entry's position
    in that listing is its class label. Every ``*.png`` file inside a class directory
    (also taken in sorted order) goes through four steps:
      * it is opened with PIL;
      * it is converted to RGB;
      * ``transforms`` is applied, if given;
      * it is stored as a ``sample``.
    A sample's ``_id`` equals its index in ``self.dataset``.

    Args:
        file_path: root directory whose sub-directories are the classes.
        transforms: optional callable applied to each PIL image at load time.
        filter_classes: optional collection of integer labels to keep. ``None`` or
            empty keeps every class.
    """
    def __init__(self, file_path, transforms=None, filter_classes=None):
        self.file_path = file_path
        self.dataset = []
        self.transforms = transforms
        # ``None`` default instead of ``[]`` avoids a mutable default shared across calls.
        self.filter_classes = [] if filter_classes is None else filter_classes
        wanted = set(self.filter_classes)
        for label, folder in tqdm(enumerate(sorted(os.listdir(file_path)))):
            if wanted and label not in wanted:
                continue
            folder_path = os.path.join(file_path, folder)
            # A stray file in the root still consumes a label index, so the numbering
            # is unchanged. It is skipped instead of crashing os.listdir().
            if not os.path.isdir(folder_path):
                continue
            for file in sorted(os.listdir(folder_path)):
                # splitext copes with names that contain several dots, or none. Unpacking
                # ``file.split('.')`` into two names raises ValueError for those.
                if os.path.splitext(file)[1] != '.png':
                    continue
                img_path = os.path.join(folder_path, file)
                img = self.loader(img_path)
                if self.transforms is not None:
                    img = self.transforms(img)
                # _id == position in self.dataset
                self.dataset.append(sample(len(self.dataset), img_path, label, label, img))

    def __getitem__(self, index):
        """Return ``(_id, sample, img, target)``; ``target`` is ``LongTensor([label])``."""
        item = self.dataset[index]
        target = torch.LongTensor([item.label])
        return item._id, item, item.img, target

    def __len__(self):
        return len(self.dataset)

    def loader(self, img_path):
        """Open ``img_path`` with PIL and return it converted to RGB."""
        with open(img_path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def set_label(self, index, label):
        """Overwrite the current label of the sample at ``index``."""
        self.dataset[index].label = label
        
from torch.utils.data import Sampler, SequentialSampler, RandomSampler, BatchSampler


class data_loader(object):
    """Iterate over a ``load_dataset``-style dataset in mini-batches.

    Each step yields a list ``[ids, imgs, labels]``:
      * ``ids``: list of the samples' ``_id`` values.
      * ``imgs``: the per-sample image tensors concatenated on dim 0, with a new
        dim 1 inserted. For single-channel ``(1, H, W)`` images this gives
        ``(B, 1, H, W)``.
      * ``labels``: the per-sample ``LongTensor([label])`` targets concatenated
        to shape ``(B,)``.

    Which indices form a batch is decided by ``random_batch_sampler``.
    """
    def __init__(self, dataset, batch_size=1, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = False
        # Kept for API compatibility only; batching is driven by ``batch_sampler``.
        self.sampler = RandomSampler(self.dataset) if shuffle else SequentialSampler(self.dataset)
        self.batch_sampler = random_batch_sampler(self.dataset, self.batch_size, shuffle, drop_last=False)

    def __iter__(self):
        for idxes in self.batch_sampler:
            ids, imgs, labels = [], [], []
            for idx in idxes:
                _id, _, img, label = self.dataset[idx]
                ids.append(_id)
                imgs.append(img)
                labels.append(label)
            # Concatenate once per batch rather than re-allocating after every sample.
            yield [ids, torch.cat(imgs, 0).unsqueeze(1), torch.cat(labels, 0)]

    def __len__(self):
        """Number of batches per epoch; the last, possibly partial, batch counts."""
        # Ceiling division: the old int(n / bs) + 1 over-counted when bs divides n.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    
class random_batch_sampler(object):
    """Yield lists of dataset indices, ``batch_size`` at a time.

    Args:
        dataset: anything with ``len()``; only its length is used.
        batch_size: number of indices per batch.
        shuffle: if true, a fresh ``torch.randperm`` order is drawn on every
            iteration; otherwise indices are yielded in order ``0..n-1``.
        drop_last: if true, a final batch shorter than ``batch_size`` is dropped.
    """
    def __init__(self, dataset, batch_size, shuffle, drop_last):
        self.n = len(dataset)
        self.batch_size = batch_size
        self.drop_last = drop_last
        # Honour the caller's flag. Hard-coding True also shuffled evaluation loaders.
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            order = torch.randperm(self.n).tolist()
        else:
            order = list(range(self.n))
        # Local state only, so two iterations of the same sampler cannot interfere.
        for start in range(0, self.n, self.batch_size):
            batch = order[start:start + self.batch_size]
            if self.drop_last and len(batch) < self.batch_size:
                break
            yield batch

    def __len__(self):
        """Number of batches one iteration yields."""
        if self.drop_last:
            return self.n // self.batch_size
        return (self.n + self.batch_size - 1) // self.batch_size

# Data loading

In [None]:
# Both splits use the same preprocessing: one grey channel, resized to 60x60, as a tensor.
_edge_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Grayscale(1),
    torchvision.transforms.Resize((60, 60)),
    torchvision.transforms.ToTensor(),
])

train_path = os.path.join(os.getcwd(), 'edge_train')
test_path = os.path.join(os.getcwd(), 'edge_test')

train_dataset = load_dataset(train_path, transforms=_edge_transforms)
train_loader = data_loader(train_dataset, batch_size=128, shuffle=True)

test_dataset = load_dataset(test_path, transforms=_edge_transforms)
test_loader = data_loader(test_dataset, batch_size=128, shuffle=False)

# Model

In [None]:
class ConvNet(nn.Module):
    """VGG-style classifier: five conv stages followed by a three-layer MLP head.

    Each stage halves the spatial size with a 2x2 max-pool. The head expects a
    flattened ``2 * 2 * 512`` feature vector, so the input's spatial size must
    reduce to 2x2 after five halvings. An example is 64x64: the 60x60 training
    images padded by 2 on each side.

    Args:
        num_classes: number of output logits (default 520).
        in_channels: channels of the input image (default 1).
    """
    def __init__(self, num_classes=520, in_channels=1):
        super(ConvNet, self).__init__()
        self.layer1 = self._stage(in_channels, 64)
        self.layer2 = self._stage(64, 128)
        self.layer3 = self._stage(128, 256)
        self.layer4 = self._stage(256, 512)
        self.layer5 = self._stage(512, 512)
        self.fc = nn.Sequential(
            nn.Linear(2 * 2 * 512, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_classes),
        )

    @staticmethod
    def _stage(in_ch, out_ch):
        """Return (conv3x3 -> BN -> ReLU) x2 followed by a 2x2 max-pool.

        The module order matches the original hand-written stages, so
        ``state_dict`` keys (``layerK.0.weight`` ... ``layerK.4.running_var``)
        and existing checkpoints stay compatible.
        """
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = torch.flatten(out, 1)  # (N, C, H, W) -> (N, C*H*W)
        return self.fc(out)


# Training Part

In [None]:
import time

def epoch_time(start_time, end_time):
    """Split the span between two timestamps (in seconds) into minutes and seconds.

    Both parts are truncated toward zero with ``int``.

    Returns:
        ``(minutes, seconds)`` as ints.
    """
    span = end_time - start_time
    whole_minutes = int(span / 60)
    leftover = span - whole_minutes * 60
    return whole_minutes, int(leftover)

In [None]:
# Pad the 60x60 inputs to 64x64 so ConvNet's five 2x2 max-pools end at 2x2.
pad = nn.ConstantPad2d(2, 0)

model = ConvNet()
model.cuda()

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-5)

# Train the model
total_step = len(train_dataset)
train_loss, train_acc = [], []  # per-epoch mean loss / accuracy
train_loss_batch = []  # mean loss of every mini-batch, across all epochs

total_category_acc = []

batch_size = 128

# Create the checkpoint directory once, up front. The old code retried mkdir
# every epoch and silently swallowed every possible error.
checkpoint_dir = os.path.join(os.getcwd(), 'Model_Result')
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in tqdm_notebook(range(100)):
    start_time = time.time()
    model.train()  # BatchNorm uses batch statistics even if eval() was called earlier
    running_loss, running_corrects = 0.0, 0
    # data_loader yields [ids, images, labels]; the ids are not needed for training.
    for i, (_, images, labels) in tqdm_notebook(enumerate(train_loader)):
        images = pad(images).cuda()
        labels = labels.cuda()
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = outputs.argmax(dim=1)
        batch_loss = loss.item()  # CrossEntropyLoss averages over the batch by default
        running_loss += batch_loss * images.size(0)  # cumulative sum of per-sample loss
        running_corrects += (preds == labels).sum().item()  # cumulative correct count (Python int)
        train_loss_batch.append(batch_loss)

    torch.save(model.state_dict(), os.path.join(checkpoint_dir, str(epoch + 1) + '_model.pt'))
    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = running_corrects / len(train_loader.dataset)

    print('Train Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print('-' * 10)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)

# Test Part

In [None]:
path = os.path.join(os.getcwd(), 'Model_Result')
result = []  # test accuracy (%) per checkpoint, in epoch order


def _checkpoint_sort_key(file_name):
    """Order '<epoch>_model.pt' names by numeric epoch; other names go last, by name.

    A plain string sort would give 1, 10, 100, 11, ... and misalign ``result``.
    """
    prefix = file_name.split('_')[0]
    if prefix.isdigit():
        return (0, int(prefix), file_name)
    return (1, 0, file_name)


checkpoints = sorted(
    (name for name in os.listdir(path) if os.path.splitext(name)[1] == '.pt'),
    key=_checkpoint_sort_key,
)

# One network is reused; load_state_dict overwrites every parameter and BatchNorm buffer.
model = ConvNet()
model.cuda()
for file_name in checkpoints:
    file = os.path.join(path, file_name)
    print(file)
    # Load onto the CPU; load_state_dict then copies the tensors into the CUDA model.
    model.load_state_dict(torch.load(file, map_location='cpu'))
    model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
    with torch.no_grad():
        correct = 0
        total = 0
        answer = []  # raw logits of the current checkpoint, one tensor per batch
        # data_loader yields [ids, images, labels]; ids are not needed here.
        for _, images, labels in test_loader:
            images = pad(images).cuda()
            labels = labels.cuda()
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            answer.append(outputs)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('Test Accuracy of the model on the test images: {} %'.format(100 * correct / total))
        result.append(100 * correct / total)

# Confusion Matrix

In [None]:
path = os.path.join(os.getcwd(), 'Model_Result')
result = []  # test accuracy (%) per checkpoint, in epoch order

total_answer = []  # per checkpoint: confusion matrix as nested lists, [true][pred] -> count
total_pred = []    # per checkpoint: {str(true_label): [[file_name, predicted_label], ...]}


def _checkpoint_sort_key(file_name):
    """Order '<epoch>_model.pt' names by numeric epoch; other names go last, by name."""
    prefix = file_name.split('_')[0]
    if prefix.isdigit():
        return (0, int(prefix), file_name)
    return (1, 0, file_name)


checkpoints = sorted(
    (name for name in os.listdir(path) if os.path.splitext(name)[1] == '.pt'),
    key=_checkpoint_sort_key,
)

# One network is reused; load_state_dict overwrites every parameter and BatchNorm buffer.
model = ConvNet()
model.cuda()
# The confusion matrix is num_classes x num_classes, one row/column per logit of
# the final Linear layer.
num_classes = model.fc[-1].out_features

for file_name in checkpoints:
    file = os.path.join(path, file_name)
    idx = int(file_name.split('_')[0]) - 1  # zero-based epoch index of this checkpoint
    print(file)
    model.load_state_dict(torch.load(file, map_location='cpu'))
    model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
    with torch.no_grad():
        correct = 0
        total = 0
        # Flattened [true * num_classes + pred] counts, accumulated on the GPU to
        # avoid a host/device round trip per sample.
        confusion = torch.zeros(num_classes * num_classes, dtype=torch.long, device='cuda')
        temp_pred = {}
        for ids, images, labels in test_loader:
            images = pad(images).cuda()
            labels = labels.cuda()
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            confusion += torch.bincount(labels * num_classes + predicted,
                                        minlength=num_classes * num_classes)
            # Record each prediction under the image's real file name. Re-deriving
            # names from iteration order only works for an unshuffled loader over
            # files named 0.png, 1.png, ...
            # Assumes _id is the sample's index in test_loader.dataset.dataset
            # (load_dataset assigns ids sequentially as it appends).
            for _id, true_label, pred_label in zip(ids, labels.tolist(), predicted.tolist()):
                img_name = os.path.basename(test_loader.dataset.dataset[_id].img_path)
                temp_pred.setdefault(str(true_label), []).append([img_name, pred_label])
        temp = confusion.view(num_classes, num_classes).tolist()
        # Per-class accuracy (%) = correct predictions of a class / its test samples.
        # Classes with no test samples get 0.
        category_test_acc = []
        for cls, row in enumerate(temp):
            support = sum(row)
            category_test_acc.append(100 * row[cls] / support if support else 0)
        print('Test Accuracy of the model on the test images: {} %'.format(100 * correct / total))
        result.append(100 * correct / total)
        total_answer.append(temp)
        total_pred.append(temp_pred)