In [None]:
import copy
import datetime
import math
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
from torchvision.utils import save_image
from tqdm.auto import tqdm

In [None]:
# Root directory of the workspace. Change this to match your own environment
# (for example the mounted Google Drive folder).
filepath = "/Users/gmmyung/Desktop/Develop/gazetrack"

# Make modules stored in the workspace importable. The membership check keeps
# a re-run of this cell from stacking duplicate entries onto sys.path.
if filepath not in sys.path:
    sys.path.append(filepath)

In [None]:
# Switch the kernel's working directory to the workspace root, then print that
# directory and list its contents as a sanity check. Any relative path used
# after this cell is resolved against this directory.
%cd $filepath 
!pwd 
!ls 

In [None]:
# --- Training hyper-parameters -------------------------------------------
n_epoch = 600             # number of passes over the training set
batch_size = 32
learning_rate = 1e-3
experiment = "exp_colab"  # experiment name; the model and optimizer are saved under it

# Create every output folder up front; exist_ok keeps the cell safe to re-run.
for _out_dir in (
    f"{filepath}/{experiment}",
    f"{filepath}/results/training",
    f"{filepath}/results/validation",
):
    os.makedirs(_out_dir, exist_ok=True)

In [None]:
class MyDataset(Dataset):
    """In-memory dataset of paired eye crops with face and mouse coordinates.

    Files read below ``root``:

    * ``data/left_eye/<id>.png`` and ``data/right_eye/<id>.png`` -- eye crops
      whose file stem is an integer id. Only the left-eye folder is scanned
      for ids; a right-eye image with the same id must exist.
    * ``data/cords.csv`` -- comma-separated numeric table. Columns 1..136 of a
      row are returned as the face coordinates, columns 137..138 as the mouse
      coordinate.

    All images are converted to grayscale and decoded once in ``__init__``,
    so construction is slow but ``__getitem__`` does no file I/O.

    NOTE(review): CSV rows are addressed by the sample's position in the
    sorted id list, not by the id itself -- this assumes row ``i`` of
    ``cords.csv`` belongs to the i-th smallest image id; confirm against the
    code that writes the CSV.
    """

    # (width, height) of the canvas every eye crop ends up on.
    EYE_SIZE = (100, 100)

    def __init__(self, root):
        """Scan ``root`` for eye images, load them all and read the coordinate table.

        Args:
            root: Directory that contains the ``data`` folder described in the
                class docstring.
        """
        self.root = root
        self.left_eye_root = os.path.join(self.root, 'data/left_eye')
        self.right_eye_root = os.path.join(self.root, 'data/right_eye')
        self.filenames = []
        self.eye_list = []
        for f in os.listdir(self.left_eye_root):
            if f.endswith('.png'):
                self.filenames.append(int(os.path.basename(os.path.splitext(f)[0])))
        # Numeric sort, so that e.g. 10 comes after 9 rather than after 1.
        self.filenames.sort()
        print(self.filenames[0:10])

        for filename in self.filenames:
            left_eye_PIL = self._load_eye(
                os.path.join(self.left_eye_root, str(filename) + '.png'))
            right_eye_PIL = self._load_eye(
                os.path.join(self.right_eye_root, str(filename) + '.png'))
            self.eye_list.append((left_eye_PIL, right_eye_PIL, str(filename)))

        # Resolve the CSV against ``root`` (like the images) instead of the
        # process working directory, so the dataset works from any cwd.
        self.cordData = np.genfromtxt(os.path.join(self.root, 'data/cords.csv'),
                                      delimiter=',')

    @staticmethod
    def _load_eye(path):
        """Load one eye crop as a grayscale PIL image of size ``EYE_SIZE``."""
        with open(path, 'rb') as f:
            # convert() decodes the pixel data, so the file can be closed afterwards.
            eye = Image.open(f).convert('L')
        # PIL's ``size`` is a (width, height) pair. Images of any other size
        # are pasted at the top-left corner of a black canvas, which pads
        # smaller crops and clips larger ones.
        if eye.size != MyDataset.EYE_SIZE:
            canvas = Image.new('L', MyDataset.EYE_SIZE, (0,))
            canvas.paste(eye)
            eye = canvas
        return eye

    def __getitem__(self, index):
        """Return ``(left_eye, right_eye, facecords, mousecord, filename)`` for one sample.

        The eye images come back as float tensors with a leading channel
        dimension of size 1; ``facecords`` and ``mousecord`` are float tensors
        sliced from row ``index`` of the coordinate table; ``filename`` is the
        image id as a string.
        """
        left_eye, right_eye, filename = self.eye_list[index]

        left_eye = torch.from_numpy(np.array(left_eye)).float().unsqueeze(0)
        right_eye = torch.from_numpy(np.array(right_eye)).float().unsqueeze(0)

        facecords = torch.from_numpy(self.cordData[index][1:137]).float()

        mousecord = torch.from_numpy(self.cordData[index][137:139]).float()

        return left_eye, right_eye, facecords, mousecord, filename

    def __len__(self):
        """Number of loaded image pairs."""
        return len(self.eye_list)

In [None]:
# Build the dataset from the workspace root configured earlier in the notebook
# rather than repeating the absolute path here.
dataset = MyDataset(root=filepath)

n_train = math.floor(0.9 * len(dataset))  # (default) 90% of the data for training
n_val = len(dataset) - n_train            # the remainder (about 10%) for validation

print('Total number of images : {}'.format(len(dataset)))

# A seeded generator makes the train/validation split identical on every run.
SPLIT_SEED = 0
train_dataset, valid_dataset = torch.utils.data.random_split(
    dataset,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        drop_last=True
                                        )
# NOTE(review): drop_last=True also discards the final partial validation
# batch, so up to batch_size - 1 validation samples are never evaluated.
valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                        batch_size=batch_size,
                                        shuffle=False,
                                        drop_last=True
                                        )

# Time how long it takes to fetch the first training batch, then show the
# batch shapes as a sanity check.
start = time.time()
left, right, face, mouse, _ = next(iter(train_loader))
print(time.time() - start)
print(left.size(), right.size(), face.size())  # shapes of the eye batches and the face-coordinate batch

In [None]:
class BasicBlock(nn.Module):
    """Residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, plus a skip path.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Base channel count of the block; the output has
            ``out_channels * expansion`` channels.
        rate: Dilation of both 3x3 convolutions. The same value is used as
            their padding.
        stride: Stride of the first 3x3 convolution and of the 1x1 projection
            on the skip path.
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, rate=1, stride=1):
        super().__init__()
        block_out = out_channels * BasicBlock.expansion

        # Main branch; only the first convolution is strided.
        main_branch = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                      dilation=rate, padding=rate, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, block_out, kernel_size=3,
                      dilation=rate, padding=rate, bias=False),
            nn.BatchNorm2d(block_out),
        ]
        self.residual_function = nn.Sequential(*main_branch)

        # Skip path: identity when the shapes already agree, otherwise a
        # strided 1x1 convolution + BN that matches channels and resolution.
        shapes_agree = stride == 1 and in_channels == block_out
        if shapes_agree:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, block_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(block_out),
            )

    def forward(self, x):
        """Add the main branch to the skip path and apply ReLU to the sum."""
        summed = self.residual_function(x) + self.shortcut(x)
        return nn.functional.relu(summed, inplace=True)

class eyeNet(nn.Module):
    """Two-tower residual CNN that regresses a 2-value target from both eyes.

    Each eye image goes through its own convolutional tower. The flattened
    tower outputs are concatenated with the extra feature vector ``c`` and fed
    to a small fully connected head that outputs two values per sample.

    The head's first layer expects 200 input features. With the 1x100x100 eye
    crops and the 136-value ``c`` used in this notebook, each tower
    contributes 32 of them.
    """

    def __init__(self):
        super().__init__()
        # Left tower is created first, then right, then the head.
        self.left = self._make_eye_tower()
        self.right = self._make_eye_tower()
        # NOTE(review): the trailing ReLU clamps predictions to >= 0 -- confirm
        # that the regression targets can never be negative.
        self.final = nn.Sequential(nn.Linear(200, 100),
                                   nn.ReLU(inplace=True),
                                   nn.Linear(100, 2),
                                   nn.ReLU(inplace=True))

    @staticmethod
    def _make_eye_tower():
        """Build one per-eye feature extractor: strided stem plus three residual blocks."""
        return nn.Sequential(nn.Conv2d(1, 8, kernel_size=3, padding=1, stride=2, bias=False),
                             nn.BatchNorm2d(8),
                             nn.ReLU(inplace=True),
                             BasicBlock(8, 8, rate=1, stride=2),
                             BasicBlock(8, 16, rate=1, stride=5),
                             BasicBlock(16, 32, rate=1, stride=5))

    def forward(self, l, r, c):
        """Predict two values per sample.

        Args:
            l: Batch of left-eye images.
            r: Batch of right-eye images.
            c: Batch of extra per-sample features, concatenated after the
                eye features along dimension 1.

        Returns:
            Tensor with two columns, one row per sample.
        """
        # The right eye now goes through its own tower. Previously both eyes
        # were passed through ``self.left`` and ``self.right`` was never used,
        # so checkpoints trained before this fix hold untrained right-tower weights.
        l, r = self.left(l), self.right(r)
        x = torch.cat((torch.flatten(l, start_dim=1), torch.flatten(r, start_dim=1), c), 1)
        x = self.final(x)
        return x


# Smoke test: run an all-zero batch of 16 samples through a freshly built
# eyeNet, then print the forward-pass wall time and the output shape.
l, r = torch.zeros(16, 1, 100, 100), torch.zeros(16, 1, 100, 100)
c = torch.zeros(16, 136)
model = eyeNet()
start = time.time()
output = model(l, r, c)
print(time.time() - start)
print(output.shape)

In [None]:
class eyeNet2(nn.Module):
    """Lightweight two-tower CNN (conv / max-pool stacks) with a 2-value output.

    Each eye image goes through its own tower of three
    Conv -> BatchNorm -> ReLU -> MaxPool stages. The flattened tower outputs
    are concatenated with the extra feature vector ``c`` and fed to a small
    fully connected head.

    The head's first layer expects 200 input features. With the 1x100x100 eye
    crops and the 136-value ``c`` used in this notebook, each tower
    contributes 32 of them.
    """

    def __init__(self):
        super().__init__()
        # Right tower is created first, then left, then the head.
        self.right = self._make_eye_tower()
        self.left = self._make_eye_tower()
        # NOTE(review): the trailing ReLU clamps predictions to >= 0 -- confirm
        # that the regression targets can never be negative.
        self.final = nn.Sequential(nn.Linear(200, 40),
                                   nn.ReLU(inplace=True),
                                   nn.Linear(40, 2),
                                   nn.ReLU(inplace=True))

    @staticmethod
    def _make_eye_tower():
        """Build one per-eye feature extractor: three conv/BN/ReLU/max-pool stages."""
        return nn.Sequential(nn.Conv2d(1, 8, kernel_size=3),
                             nn.BatchNorm2d(8),
                             nn.ReLU(inplace=True),
                             nn.MaxPool2d(3),
                             nn.Conv2d(8, 8, kernel_size=3),
                             nn.BatchNorm2d(8),
                             nn.ReLU(inplace=True),
                             nn.MaxPool2d(3),
                             nn.Conv2d(8, 8, kernel_size=3),
                             nn.BatchNorm2d(8),
                             nn.ReLU(inplace=True),
                             nn.MaxPool2d(3))

    def forward(self, l, r, c):
        """Predict two values per sample.

        Args:
            l: Batch of left-eye images.
            r: Batch of right-eye images.
            c: Batch of extra per-sample features, concatenated after the
                eye features along dimension 1.

        Returns:
            Tensor with two columns, one row per sample.
        """
        # The right eye now goes through its own tower. Previously both eyes
        # were passed through ``self.left`` and ``self.right`` was never used.
        l, r = self.left(l), self.right(r)
        x = torch.cat((torch.flatten(l, start_dim=1), torch.flatten(r, start_dim=1), c), 1)
        x = self.final(x)
        return x

def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` flag is off are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Smoke test: run an all-zero batch of 16 samples through a freshly built
# eyeNet2, then print the forward-pass wall time, the output shape and the
# number of trainable parameters.
l, r = torch.zeros(16, 1, 100, 100), torch.zeros(16, 1, 100, 100)
c = torch.zeros(16, 136)
model = eyeNet2()
start = time.time()
output = model(l, r, c)
print(time.time() - start)
print(output.shape)
print(count_parameters(model))

**8. Set up the training environment (model, loss function, optimizer, etc.).**


In [None]:
# Training set-up; the loss, the model and the optimizer can all be swapped here.
criterion = nn.MSELoss()  # mean squared error between the prediction and the target
model = eyeNet()
optim = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Holders for the best / most recent model and optimizer state dicts; all start empty.
best_model_state = best_optim_state = None
latest_model_state = latest_optim_state = None

**9. Check your gaze-estimation network.**

You can define your own network in **model_baseline.py**; this notebook trains the `eyeNet` class defined in the cells above.

In [None]:
# Print the layer-by-layer structure of the model selected for training.
print('The network architecture is as follows.', model, sep='\n')

In [None]:
start = time.time()

# Lowest validation loss seen so far; decides when the "best" snapshot is refreshed.
best_valid_loss = float("inf")

for epoch in tqdm(range(n_epoch)):
    train_total_loss, valid_total_loss = 0.0, 0.0
    # Samples actually processed this epoch. Used to normalise the losses
    # instead of a hard-coded dataset size.
    n_train_seen, n_valid_seen = 0, 0

    # --------------
    # Training step
    # --------------
    model.train()
    # At first loading, it may be stuck for a while..
    for l, r, c, m, _ in train_loader:
        # Predict the two target values from both eye images and the extra features
        pred = model(l, r, c)

        # Calculate loss
        loss = criterion(pred, m)

        # Backpropagate the loss to update network's weights
        optim.zero_grad()
        loss.backward()
        optim.step()

        # criterion returns a batch mean; weight it by the real batch size so
        # the epoch figure is a per-sample average.
        train_total_loss += loss.item() * m.size(0)
        n_train_seen += m.size(0)
    train_total_loss /= max(n_train_seen, 1)  # max() guards an empty loader

    # ----------------
    # Validation step
    # ----------------
    model.eval()
    with torch.no_grad():
        # Load mini-batches and do validation
        for l, r, c, m, _ in valid_loader:
            pred = model(l, r, c)
            loss = criterion(pred, m)
            valid_total_loss += loss.item() * m.size(0)
            n_valid_seen += m.size(0)
    valid_total_loss /= max(n_valid_seen, 1)

    # Latest snapshot: state_dict() returns references to the live tensors,
    # which is what "latest" should track.
    latest_model_state = model.state_dict()
    latest_optim_state = optim.state_dict()

    # Best snapshot: deep-copied so later updates cannot overwrite it. Skipped
    # when nothing was validated, because a loss of 0.0 would then be meaningless.
    if n_valid_seen > 0 and valid_total_loss < best_valid_loss:
        best_valid_loss = valid_total_loss
        best_model_state = copy.deepcopy(model.state_dict())
        best_optim_state = copy.deepcopy(optim.state_dict())

    # Print training logs every 10 epochs and on the final epoch
    if epoch % 10 == 0 or epoch == n_epoch - 1: 
        print(f"""\n{time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())} || [{epoch}/{n_epoch}], train_loss = {train_total_loss:.4f}, valid_loss = {valid_total_loss:.4f}""")

elapsed = time.time() - start
print(f"End of training, elapsed time : {elapsed // 60} min {elapsed % 60} sec.")

In [None]:
# Write the recorded state dicts into the experiment folder. A state that is
# still None was never recorded; it is skipped with a message instead of being
# pickled as None, which would leave an unusable checkpoint on disk.
checkpoints = {
    "best_model_state_dict.pt": best_model_state,
    "best_optim_state_dict.pt": best_optim_state,
    "latest_model_state_dict.pt": latest_model_state,
    "latest_optim_state_dict.pt": latest_optim_state,
}
for checkpoint_name, state in checkpoints.items():
    if state is None:
        print(f"Skipping {checkpoint_name}: no state was recorded.")
        continue
    torch.save(state, f"{filepath}/{experiment}/{checkpoint_name}")
print("Successfully saved.")