<a href="https://colab.research.google.com/github/kjinb1212/Falldown-detection-KD/blob/main/KD_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import train_test_split

from glob import glob
import os
import sys
from PIL import Image
import timm
import pandas as pd

from efficientnet_pytorch import EfficientNet



# KD train

In [None]:
def loss_fn_kd(outputs, labels, teacher_outputs, T, alpha):
    """Compute the knowledge-distillation loss for one batch.

    The result is a weighted sum of two terms:

    * the KL divergence between the temperature-softened teacher and student
      distributions, weighted by ``alpha * T * T``;
    * the cross-entropy between the raw student logits and the hard labels,
      weighted by ``1 - alpha``.

    Args:
        outputs: student logits; the softmax is taken over dim 1.
        labels: target class indices, passed to ``F.cross_entropy``.
        teacher_outputs: teacher logits, expected to have the same shape as
            ``outputs``.
        T: softmax temperature applied to both sets of logits.
        alpha: weight of the distillation term.

    Returns:
        Scalar tensor holding the combined loss.
    """
    student_log_probs = F.log_softmax(outputs / T, dim=1)
    teacher_probs = F.softmax(teacher_outputs / T, dim=1)

    # 'batchmean' sums over the class dimension and averages over the batch,
    # which is the actual KL divergence. The default 'mean' reduction also
    # divides by the number of classes and therefore under-weights this term.
    soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    hard_loss = F.cross_entropy(outputs, labels)

    return soft_loss * (alpha * T * T) + hard_loss * (1. - alpha)

def get_teacher_output(model, loader):
    """Run ``model`` over ``loader`` once and collect its raw outputs.

    The model is switched to eval mode (and left in it) and gradients are
    disabled while the batches are evaluated.

    Args:
        model: network to evaluate.
        loader: iterable yielding ``(data, label)`` pairs; labels are ignored.

    Returns:
        A list with one output tensor per batch, in the order the batches
        were produced by ``loader``. The tensors stay on the model's device.
    """
    model.eval()

    # Use the device the model already lives on rather than depending on a
    # module-level `device` variable being defined before this call.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        output = [model(data.to(device)) for data, _ in loader]
    torch.cuda.empty_cache()
    return output

In [None]:
def train_kd(model, teacher_output, train_loader, test_loader, criterion,
             optimizer, epochs, T, alpha, save_name):
    """Train a student model with knowledge distillation and early stopping.

    Each epoch trains on ``train_loader`` using ``loss_fn_kd`` and then
    evaluates on ``test_loader`` using ``criterion``. Whenever the test loss
    improves, the student's ``state_dict`` is written to
    ``new_model_weights/<save_name>.pth``. Training stops early once the test
    loss has failed to improve for more than 7 consecutive epochs. The
    learning rate is reduced on plateau of the test loss.

    Args:
        model: student network to train.
        teacher_output: sequence of teacher logits indexed by batch position;
            ``teacher_output[i]`` is used as the target for the i-th batch of
            ``train_loader``, so the loader must yield the same batches in the
            same order on every pass.
        train_loader: iterable of ``(data, label)`` training batches.
        test_loader: iterable of ``(data, label)`` evaluation batches.
        criterion: loss used for evaluation only.
        optimizer: optimizer bound to ``model``'s parameters.
        epochs: maximum number of epochs.
        T: distillation temperature forwarded to ``loss_fn_kd``.
        alpha: distillation weight forwarded to ``loss_fn_kd``.
        save_name: checkpoint file stem.

    Returns:
        ``(best_loss, best_acc, history)`` where ``best_acc`` is the accuracy
        (in percent) of the epoch with the lowest test loss and ``history`` is
        a dict with per-epoch ``'loss'`` and ``'acc'`` lists.
    """
    # Follow the device the student already lives on instead of depending on
    # a module-level `device` variable.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                          patience=4, verbose=True)

    # Make sure the checkpoint directory exists before the first save.
    checkpoint_path = 'new_model_weights/' + save_name + '.pth'
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    best_loss = None
    best_acc = None
    patience = 0

    history = {'loss': [], 'acc': []}

    for epoch in range(epochs):
        print("--------- epoch : {} ------------".format(epoch+1))

        # ---- training pass ----
        model.train()
        train_losses = []
        for i, (data, label) in enumerate(train_loader):
            data = data.to(device)
            label = label.to(device)

            optimizer.zero_grad()
            output = model(data)
            # Teacher logits are looked up by batch position.
            loss = loss_fn_kd(output, label, teacher_output[i].to(device), T, alpha)
            train_losses.append(loss.item())
            loss.backward()
            optimizer.step()

        train_loss = np.average(train_losses)
        print("train loss: {}".format(train_loss))

        # ---- evaluation pass ----
        model.eval()
        test_losses = []
        correct = 0
        total = 0
        with torch.no_grad():
            for data, label in test_loader:
                data = data.to(device)
                label = label.to(device)

                output = model(data)
                test_losses.append(criterion(output, label).item())
                predict = output.argmax(dim=1)
                correct += (predict == label).sum().item()
                total += label.size(0)

        test_loss = np.average(test_losses)
        test_acc = 100 * correct / total
        print("test loss: {}, \t test acc: {}%".format(test_loss, test_acc))

        history['loss'].append(test_loss)
        history['acc'].append(test_acc)

        # ---- checkpointing / early stopping ----
        if (best_loss is None) or (best_loss > test_loss):
            best_loss = test_loss
            best_acc = test_acc
            torch.save(model.state_dict(), checkpoint_path)
            print('Best loss: {}\n'.format(best_loss))
            patience = 0
        else:
            patience += 1

        if patience > 7:
            print("early stop at {} epoch".format(epoch + 1))
            break

        scheduler.step(metrics=test_loss)

    print("best loss: {},\t best acc: {}%\n\n".format(best_loss, best_acc))
    return best_loss, best_acc, history
    

# Custom Dataset


In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Three-class image dataset read from ``root_dir``.

    JPEG files are collected from the ``normal``, ``falldown`` and
    ``background`` sub-directories and labelled 0, 1 and 2 respectively.
    Images whose shorter side is below 32 pixels are skipped.

    Args:
        root_dir: directory containing the three class sub-directories.
        input_size: every sample is resized to ``(input_size, input_size)``.
        train: when true, random horizontal flip and random rotation are
            added in front of the tensor conversion.
        padding: when true, the tensor is zero-padded towards a 128x128
            canvas before the final resize.
        normalize: when true, apply the fixed per-channel mean/std
            normalization.
        bright_ness, hue, contrast: currently unused (the colour-jitter
            transform is disabled); kept so existing callers keep working.
        random_Hflip: probability of the horizontal flip.
        rotate_deg: rotation range in degrees for the random rotation.
    """

    # Images with a shorter side below this many pixels are left out.
    _MIN_SIDE = 32
    # Canvas size the optional zero padding aims for.
    _PAD_TARGET = 128

    def __init__(self, root_dir, input_size, train=True, padding=True, normalize=False,
                 bright_ness=0.2, hue=0.15, contrast=0.15, random_Hflip=0.3, rotate_deg=20):
        normal_paths = self._usable_images(os.path.join(root_dir, 'normal'))
        fall_paths = self._usable_images(os.path.join(root_dir, 'falldown'))
        back_paths = self._usable_images(os.path.join(root_dir, 'background'))

        self.total_paths = normal_paths + fall_paths + back_paths
        self.labels = [0] * len(normal_paths) + [1] * len(fall_paths) + [2] * len(back_paths)

        transform = []
        if train:
            transform.append(torchvision.transforms.RandomHorizontalFlip(p=random_Hflip))
            transform.append(torchvision.transforms.RandomRotation(degrees=rotate_deg))
        transform.append(torchvision.transforms.ToTensor())
        if normalize:
            transform.append(torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                              std=[0.229, 0.224, 0.225]))
        if padding:
            transform.append(self._pad_to_canvas)
        transform.append(torchvision.transforms.Resize((input_size, input_size)))
        self.transform = torchvision.transforms.Compose(transform)

    @staticmethod
    def _usable_images(class_dir):
        """Return the ``*.jpg`` paths in ``class_dir`` that are large enough."""
        kept = []
        for path in glob(class_dir + '/*.jpg'):
            # Close each file right after reading its size so scanning a big
            # directory does not pile up open file handles.
            with Image.open(path) as img:
                if min(img.size[0], img.size[1]) >= CustomDataset._MIN_SIDE:
                    kept.append(path)
        return kept

    @staticmethod
    def _pad_to_canvas(x):
        """Zero-pad a ``(C, H, W)`` tensor symmetrically towards the canvas size."""
        pad = ((CustomDataset._PAD_TARGET - x.shape[2]) // 2,
               (CustomDataset._PAD_TARGET - x.shape[1]) // 2)
        return torchvision.transforms.Pad(pad, fill=0, padding_mode="constant")(x)

    def __len__(self):
        """Return the number of usable images."""
        return len(self.total_paths)

    def __getitem__(self, index):
        """Return ``(image_tensor, label)`` for the sample at ``index``."""
        with Image.open(self.total_paths[index]) as img:
            # Force three channels so grayscale or CMYK JPEGs still produce a
            # tensor the 3-channel normalization and networks can consume.
            img = self.transform(img.convert('RGB'))
        return img, self.labels[index]

# student model 

In [None]:
class CNN_layers(nn.Module):
    """Small CNN student: four conv stages followed by three linear layers.

    Each conv stage is zero-pad -> 3x3 conv -> batch norm -> ReLU -> 2x2 max
    pool, which halves the spatial size. The first linear layer expects
    ``32 * 8 * 8`` features, i.e. a 128x128 input. The network outputs 3
    logits per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.conv3 = nn.Conv2d(32, 64, 3)
        self.conv4 = nn.Conv2d(64, 32, 3)

        self.fc1 = nn.Linear(32 * 8 * 8, 16)
        self.fc2 = nn.Linear(16, 8)
        self.fc3 = nn.Linear(8, 3)

        self.bn1 = nn.BatchNorm2d(16)
        self.bn2 = nn.BatchNorm2d(32)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm2d(32)

        # There is no bn5; the names bn6/bn7 are kept as they are so that
        # previously saved state_dicts keep loading.
        self.bn6 = nn.BatchNorm1d(16)
        self.bn7 = nn.BatchNorm1d(8)
        self.padding = nn.ZeroPad2d(1)

    def forward(self, x):
        """Return class logits of shape ``(batch, 3)`` for a 128x128 input batch."""
        conv_stages = (
            (self.conv1, self.bn1),  # 128 -> 64
            (self.conv2, self.bn2),  # 64 -> 32
            (self.conv3, self.bn3),  # 32 -> 16
            (self.conv4, self.bn4),  # 16 -> 8
        )
        for conv, bn in conv_stages:
            x = self.pool(F.relu(bn(conv(self.padding(x)))))

        # Flatten per sample. Unlike view(-1, ...), this never folds spatial
        # positions into the batch dimension when the input size is wrong;
        # fc1 then fails with a clear shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.bn6(self.fc1(x)))
        x = F.relu(self.bn7(self.fc2(x)))
        x = self.fc3(x)
        return x

# create data loader 

In [None]:
# Run configuration shared by the datasets, loaders and training loop.
INPUT_SIZE = 128    # images are resized to INPUT_SIZE x INPUT_SIZE
PADDING = False     # no zero padding before the resize
NORMALIZE = False   # no mean/std normalization
BATCHSIZE = 128
NUMEPOCH = 100      # upper bound on epochs per training run

# Flip probability and rotation range are both 0, so the training transforms
# are effectively no-ops. `hue` was written as 01.5 (which evaluates to 1.5);
# 0.15 is used instead so it matches the dataset default and stays a valid
# value should colour jitter be enabled.
train_data = CustomDataset(
    root_dir='train',
    input_size=INPUT_SIZE, train=True, padding=PADDING, normalize=NORMALIZE,
    bright_ness=0, hue=0.15, contrast=0.15, random_Hflip=0, rotate_deg=0)

test_data = CustomDataset(
    root_dir='validation',
    input_size=INPUT_SIZE, train=False, padding=PADDING, normalize=NORMALIZE,
    bright_ness=0, hue=0.15, contrast=0.15, random_Hflip=0, rotate_deg=0)

In [None]:
# The teacher logits are computed once over train_loader and later looked up
# by batch position during student training, so the training batches must come
# out in the same order on every pass. The dataset itself is grouped by class,
# so shuffle it once up front and keep that order fixed instead of letting the
# loader reshuffle every epoch.
train_order = torch.randperm(len(train_data)).tolist()
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_data, train_order),
    batch_size=BATCHSIZE, shuffle=False, num_workers=32, drop_last=True)

# Evaluate on every validation image: no shuffling and no dropped tail batch.
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=BATCHSIZE, shuffle=False, num_workers=32, drop_last=False)

# train student model by KD

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
teachers = ['efficientnet-b0', 'efficientnet-b1', 'efficientnet-b4']

# Hyper-parameter grid: distillation temperature and distillation weight.
T = [2, 3, 4, 5, 6, 7]
ALPHA = [0.001, 0.01, 0.1, 0.5, 0.9]

# The per-run history CSVs are written here.
os.makedirs('new_history', exist_ok=True)

logs = []
for teacher_name in teachers:
    torch.cuda.empty_cache()
    teacher_model = EfficientNet.from_pretrained(teacher_name, num_classes=3).to(device)
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    teacher_model.load_state_dict(
        torch.load('new_model_weights/' + teacher_name + '.pth', map_location=device))

    teacher_output = get_teacher_output(teacher_model, train_loader)

    torch.cuda.empty_cache()

    criterion = nn.CrossEntropyLoss()
    save_name = 'kd_' + teacher_name
    best_grid_loss = None

    for t in T:
        for j, alpha in enumerate(ALPHA):
            # Every (T, alpha) combination starts from a freshly initialised
            # student and optimizer; otherwise each run would continue from
            # the previous run's weights and the results could not be compared.
            student_model = CNN_layers().to(device)
            optimizer = torch.optim.Adam(student_model.parameters(), weight_decay=1e-4, lr=0.001)

            # One checkpoint per run so runs do not overwrite each other.
            run_name = save_name + '_T{}_al{}'.format(t, alpha)
            loss, acc, history = train_kd(model=student_model, teacher_output=teacher_output,
                                          train_loader=train_loader, test_loader=test_loader,
                                          criterion=criterion, optimizer=optimizer,
                                          epochs=NUMEPOCH, T=t, alpha=alpha, save_name=run_name)

            # Keep 'kd_<teacher>.pth' pointing at the best student of the grid.
            if loss is not None and (best_grid_loss is None or loss < best_grid_loss):
                best_grid_loss = loss
                best_state = torch.load('new_model_weights/' + run_name + '.pth',
                                        map_location=device)
                torch.save(best_state, 'new_model_weights/' + save_name + '.pth')

            # The log line identifies alpha by its index in ALPHA.
            s = teacher_name +'_T{}_al{}\tloss = {}, \tacc = {}'.format(t, j, loss, acc)
            logs.append(s)
            df = pd.DataFrame(history)
            df.to_csv("new_history/" + run_name + "_history.csv", mode='w')

Loaded pretrained weights for efficientnet-b0
--------- epoch : 1 ------------
train loss: 0.6361562737628169
test loss: 0.5609208233654499, 	 test acc: 76.26953125%
Best loss: 0.5609208233654499

--------- epoch : 2 ------------
train loss: 0.2607366149516209
test loss: 0.6808250192552805, 	 test acc: 73.14453125%
--------- epoch : 3 ------------
train loss: 0.17448783892652261
test loss: 0.35910811088979244, 	 test acc: 86.669921875%
Best loss: 0.35910811088979244

--------- epoch : 4 ------------
train loss: 0.13868358729245223
test loss: 0.9095166698098183, 	 test acc: 64.6484375%
--------- epoch : 5 ------------
train loss: 0.11426037527701777
test loss: 0.4371277401223779, 	 test acc: 83.349609375%
--------- epoch : 6 ------------
train loss: 0.10253044233009544
test loss: 0.5280382707715034, 	 test acc: 81.103515625%
--------- epoch : 7 ------------
train loss: 0.09270014061147104
test loss: 0.501621063798666, 	 test acc: 81.15234375%
--------- epoch : 8 ------------
train loss:

KeyboardInterrupt: 

In [None]:
# In each log line, "al<k>" is the index of the alpha value in
# ALPHA = [0.001, 0.01, 0.1, 0.5, 0.9], not the value itself.

for entry in logs:
    print(entry)

efficientnet-b0_T2_al0	loss = 0.35910811088979244, 	acc = 86.669921875
efficientnet-b0_T2_al1	loss = 0.38970668613910675, 	acc = 87.158203125
efficientnet-b0_T2_al2	loss = 0.392495047301054, 	acc = 87.59765625
efficientnet-b0_T2_al3	loss = 0.39769905991852283, 	acc = 87.40234375
efficientnet-b0_T2_al4	loss = 0.4000811204314232, 	acc = 86.81640625
efficientnet-b0_T3_al0	loss = 0.39693687204271555, 	acc = 87.01171875
efficientnet-b0_T3_al1	loss = 0.4000391745939851, 	acc = 87.255859375
efficientnet-b0_T3_al2	loss = 0.4001190410926938, 	acc = 87.060546875
efficientnet-b0_T3_al3	loss = 0.3963844859972596, 	acc = 87.353515625
efficientnet-b0_T3_al4	loss = 0.39185390807688236, 	acc = 87.353515625
efficientnet-b0_T4_al0	loss = 0.4020005688071251, 	acc = 86.865234375
efficientnet-b0_T4_al1	loss = 0.39988416340202093, 	acc = 87.20703125
efficientnet-b0_T4_al2	loss = 0.39767225086688995, 	acc = 87.060546875
efficientnet-b0_T4_al3	loss = 0.397953636944294, 	acc = 87.20703125
efficientnet-b0_T4_al