In [1]:
import sys; sys.path.insert(0, '..') # add parent folder path where lib folder is
import modules.data

In [7]:
from modules.data import JAFFEDataset
from modules.data import JAFFEDataModule

import numpy as np
import pandas as pd

from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset

import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM

from sklearn.model_selection import StratifiedShuffleSplit

from tqdm import tqdm, trange
import matplotlib.pyplot as plt

from os.path import basename

class DeepEmotion(pl.LightningModule):
    def __init__(self):
        '''
        Deep_Emotion class contains the network architecture.
        '''
        super().__init__()
        
        self.conv1 = nn.Conv2d(1,10,3)
        self.conv2 = nn.Conv2d(10,10,3)
        self.pool2 = nn.MaxPool2d(2,2)

        self.conv3 = nn.Conv2d(10,10,3)
        self.conv4 = nn.Conv2d(10,10,3)
        self.pool4 = nn.MaxPool2d(2,2)

        self.norm = nn.BatchNorm2d(10)

        self.fc1 = nn.Linear(810,50)
        self.fc2 = nn.Linear(50,7)

        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        self.fc_loc = nn.Sequential(
            nn.Linear(640, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
    
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 640)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)
        return x
        return xs

    def forward(self,input):
        out = self.stn(input)

        out = F.relu(self.conv1(out))
        out = self.conv2(out)
        out = F.relu(self.pool2(out))

        out = F.relu(self.conv3(out))
        out = self.norm(self.conv4(out))
        out = F.relu(self.pool4(out))

        out = F.dropout(out)
        out = out.view(-1, 810)
        out = F.relu(self.fc1(out))
        out = self.fc2(out)

        return out
    
    def training_step(self, batch, batch_idx):
        # training_step defined the train loop. It is independent of forward
        image, desc = batch
        label = desc['exp']
        
        print(image.size())
        predictions = self(image)

        print(predictions)
        
        
        loss = nn.functional.cross_entropy(predictions, label)

        result = pl.TrainResult(loss)
        result.log('train_loss', loss)
        return result
        
        result =  pl.TrainResult()
        result.log('smth', 0)
        return result

    def test_step(self, batch, batch_idx):
        image, desc = batch
        label = desc['exp']
        
        predictions = self(image)

        loss = nn.functional.cross_entropy(predictions, label)

        label_hat = predictions.argmax(dim=1).flatten()
        
        accuracy = FM.accuracy(label, label_hat, num_classes=2)

        result = pl.EvalResult(checkpoint_on=loss)

        result.batch_acc = accuracy
        result.batch_len = label.shape[0]
        result.batch_loss = loss

        return result

    def test_epoch_end(self, outputs):
        all_accs = outputs.batch_acc
        all_loss = outputs.batch_loss
        all_lens = outputs.batch_len
        all_lens = torch.tensor(all_lens, dtype=torch.float, device=self.device)

        epoch_acc = torch.dot(all_accs, all_lens) / all_lens.sum()
        epoch_loss = torch.dot(all_loss, all_lens) / all_lens.sum()

        result = pl.EvalResult(checkpoint_on=epoch_loss)

        result.log('test_acc' , epoch_acc , on_step=False, on_epoch=True)
        result.log('test_loss', epoch_loss, on_step=False, on_epoch=True)

        return result

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    
pl.seed_everything(42)

jaffe = JAFFEDataModule()
jaffe.setup()


deep_emotion = DeepEmotion()

# trainer = pl.Trainer(gpus=-1, max_epochs=6)
trainer = pl.Trainer( max_epochs=6)
trainer.fit(deep_emotion, jaffe)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

   | Name         | Type        | Params
----------------------------------------------
0  | conv1        | Conv2d      | 100   
1  | conv2        | Conv2d      | 910   
2  | pool2        | MaxPool2d   | 0     
3  | conv3        | Conv2d      | 910   
4  | conv4        | Conv2d      | 910   
5  | pool4        | MaxPool2d   | 0     
6  | norm         | BatchNorm2d | 20    
7  | fc1          | Linear      | 40 K  
8  | fc2          | Linear      | 357   
9  | localization | Sequential  | 2 K   
10 | fc_loc       | Sequential  | 20 K  


JAFFE data found
Epoch 0:   0%|          | 0/6 [00:00<?, ?it/s] torch.Size([32, 1, 48, 48])
tensor([[ 3.7465e-01, -3.8261e-02,  3.0418e-01, -8.2634e-01, -8.8386e-01,
          6.1403e-01, -9.5638e-01],
        [ 2.6466e-01, -2.7815e-01, -2.1599e-01,  6.9167e-02, -4.8409e-01,
          3.5739e-01, -4.8492e-01],
        [ 3.7114e-01, -1.9240e-01, -4.6527e-02, -3.7423e-01, -3.2346e-01,
          1.7331e-01, -7.1305e-01],
        [ 1.3894e-01, -2.7555e-01, -1.5183e-01, -1.4213e-02, -2.6891e-01,
          6.7526e-01, -9.1423e-01],
        [ 2.8601e-01, -2.5629e-01, -4.2176e-01, -1.8232e-01, -8.7935e-02,
          5.4570e-01, -6.2314e-01],
        [-1.0005e-01, -2.8095e-01, -1.5021e-01, -2.3510e-01, -2.9435e-01,
          4.4562e-01, -9.4803e-01],
        [ 8.5735e-02,  3.2567e-01,  2.8544e-01, -5.7256e-01, -2.2005e-01,
          7.2381e-01, -3.7604e-01],
        [ 2.3748e-01, -2.3519e-01,  1.3910e-01, -4.4526e-01, -1.4005e-01,
          4.7511e-01, -6.2438e-01],
        [-4.6546e-01,  2.380



Epoch 0:  33%|███▎      | 2/6 [00:02<00:05,  1.41s/it, loss=1.989, v_num=47]torch.Size([32, 1, 48, 48])
tensor([[ 1.1619e-01,  3.2434e-01, -8.1933e-02,  5.2304e-02, -4.4306e-01,
         -5.8329e-02, -4.6302e-01],
        [ 2.3860e-01,  4.0682e-01, -2.2400e-01,  1.9629e-01, -3.7243e-02,
         -8.2671e-02, -4.6480e-01],
        [ 1.1040e-01,  4.2322e-01, -1.6137e-01,  7.8457e-02, -6.3552e-01,
          5.6188e-02, -7.2235e-01],
        [-5.3912e-02,  1.4784e-01, -1.5216e-01,  1.4311e-01, -2.3168e-01,
          2.8845e-01, -6.1145e-01],
        [-1.5966e-01,  2.7682e-01, -3.3282e-01,  2.2026e-01, -6.5388e-01,
          2.8040e-01, -1.1245e+00],
        [ 3.0053e-01,  1.7280e-01, -3.2391e-01,  1.2279e-01, -6.8574e-02,
          1.8219e-02, -6.7592e-01],
        [-5.8858e-01,  3.3442e-01, -4.2318e-01, -3.4517e-01, -5.4871e-01,
          4.3372e-01, -2.2724e-01],
        [-2.7483e-01,  1.3866e-01, -3.3950e-01, -9.7888e-02, -3.8625e-01,
          4.6467e-01, -7.4476e-01],
        [ 5.2409

Epoch 0:  83%|████████▎ | 5/6 [00:03<00:00,  1.60it/s, loss=1.997, v_num=47]torch.Size([31, 1, 48, 48])
tensor([[-0.2680,  0.1749, -0.3954, -0.0553,  0.0057,  0.0114, -0.5902],
        [-0.2030,  0.0199, -0.4316,  0.0943, -0.3083, -0.0431, -0.2308],
        [ 0.0444, -0.2476, -0.5701,  0.2004,  0.0549, -0.0326, -0.5490],
        [-0.0356,  0.0696, -0.6828,  0.0577, -0.0221,  0.0283, -0.5032],
        [ 0.0011, -0.5025, -0.5552, -0.0388,  0.0827, -0.1815, -0.5084],
        [-0.2075,  0.0435, -0.3256,  0.1410,  0.1401, -0.1340, -0.4288],
        [-0.1245, -0.0205, -0.2166,  0.1164, -0.1209, -0.1212, -0.0635],
        [ 0.0844, -0.1530, -0.5309,  0.2162, -0.0184, -0.4066, -0.4185],
        [ 0.0021, -0.0972, -0.1769, -0.0528,  0.0977, -0.1349, -0.2139],
        [-0.2616, -0.0372, -0.2152, -0.1642, -0.1444, -0.1783, -0.3743],
        [-0.2124,  0.1371, -0.2132, -0.1336, -0.0474, -0.1557, -0.3444],
        [ 0.0455, -0.2762, -0.5112,  0.1358, -0.0056,  0.0392, -0.4627],
        [-0.0865,  0

Epoch 1:  50%|█████     | 3/6 [00:02<00:02,  1.06it/s, loss=1.971, v_num=47]torch.Size([32, 1, 48, 48])
tensor([[-1.2162e-01, -5.8349e-02, -3.5295e-01,  3.8567e-01, -3.4761e-02,
         -3.4437e-01,  1.7392e-01],
        [-1.5049e-01, -2.5839e-01, -2.0917e-01,  2.3548e-01,  2.3349e-01,
         -2.1941e-01,  7.1903e-02],
        [-2.1732e-01, -4.0710e-02, -2.0458e-01, -1.1923e-01, -5.3214e-02,
         -3.8083e-01,  8.9747e-02],
        [-4.1557e-03, -6.6528e-02, -2.4533e-01,  1.0624e-01,  1.0407e-02,
         -3.8262e-01,  1.8575e-01],
        [-9.3395e-02, -4.7028e-01, -5.3731e-01,  1.6065e-01,  4.7645e-01,
         -4.4005e-01, -3.2167e-01],
        [-2.4213e-01,  1.5327e-01, -1.2314e-01,  1.9072e-02,  1.3953e-01,
         -3.7847e-01, -3.4664e-02],
        [-3.0875e-01, -5.5004e-01, -4.8802e-01,  1.8106e-01,  6.1432e-02,
         -1.5334e-01, -1.1198e-02],
        [-8.1659e-02, -2.0594e-01, -1.3695e-02,  2.6060e-02,  2.1416e-01,
         -1.3413e-01, -4.0696e-01],
        [-1.4175

Epoch 2:   0%|          | 0/6 [00:00<?, ?it/s, loss=1.959, v_num=47]        torch.Size([32, 1, 48, 48])
tensor([[-0.1910, -0.2470, -0.3537, -0.0921, -0.0138, -0.2880,  0.0889],
        [-0.1349, -0.0346, -0.2352, -0.0533, -0.1214, -0.3896, -0.0225],
        [-0.2944, -0.1311, -0.3239, -0.0358, -0.1207, -0.2682, -0.0323],
        [-0.2076,  0.1710, -0.0557, -0.2694, -0.0496, -0.5724,  0.3838],
        [-0.1329, -0.0369, -0.1624, -0.2664,  0.0160, -0.3607, -0.0449],
        [ 0.0086, -0.0876, -0.2358, -0.1599,  0.0667, -0.3169,  0.3020],
        [-0.2722, -0.0699, -0.3658,  0.1876,  0.1947, -0.4470,  0.0854],
        [-0.0650, -0.6356, -0.3191, -0.0786,  0.2614, -0.1815, -0.2158],
        [-0.1494, -0.2026, -0.4036, -0.0326,  0.1401, -0.4306,  0.2291],
        [-0.1545, -0.2296, -0.2253, -0.0919,  0.1953, -0.3597, -0.4500],
        [-0.2621,  0.0125,  0.0281,  0.1271,  0.0632, -0.3205,  0.1098],
        [ 0.0243, -0.0933, -0.4508,  0.0609, -0.1926, -0.4211,  0.2939],
        [-0.3645, -0

Epoch 2:  67%|██████▋   | 4/6 [00:03<00:01,  1.33it/s, loss=1.937, v_num=47]torch.Size([32, 1, 48, 48])
tensor([[-1.4468e-01, -3.3269e-01, -4.1871e-01, -2.9435e-01,  9.8291e-02,
         -4.1241e-01, -4.3655e-01],
        [-2.5898e-01,  1.7698e-02, -1.0311e-01, -2.7475e-01, -1.1114e-01,
         -5.8002e-01,  2.8943e-01],
        [-3.4193e-01,  1.5409e-02,  1.3831e-01, -5.0217e-01,  1.1654e-01,
         -2.3796e-02, -1.3840e-01],
        [-1.9081e-01, -4.8674e-01, -3.8764e-01, -1.9639e-01, -1.5167e-02,
         -2.2312e-01,  2.5628e-03],
        [-2.5378e-01,  1.1450e-01, -5.5564e-02, -3.6178e-01, -4.2400e-03,
         -4.3979e-01, -2.2362e-01],
        [-3.3767e-01, -6.3005e-02,  2.4137e-01, -2.8770e-01,  8.3842e-02,
         -4.1795e-01, -4.0412e-02],
        [-2.7802e-01, -2.4591e-01, -2.6939e-01, -2.3950e-02,  1.1521e-02,
         -4.7088e-01,  5.6451e-02],
        [-1.5431e-01, -3.5832e-02, -1.5928e-01,  1.8263e-02,  7.7179e-02,
         -4.3395e-01, -7.5884e-02],
        [-2.1133

Epoch 3:  17%|█▋        | 1/6 [00:02<00:13,  2.79s/it, loss=1.922, v_num=47]torch.Size([32, 1, 48, 48])
tensor([[-7.2811e-02, -1.5585e-01, -2.5836e-01,  4.0794e-01,  6.7052e-02,
         -4.2699e-01, -3.1921e-01],
        [-2.2190e-01, -1.9997e-01, -6.6059e-01, -4.3415e-01, -1.2055e-01,
         -7.9040e-01, -2.9768e-01],
        [-1.4019e-01, -3.6275e-01, -2.1661e-01, -2.5767e-01, -5.2825e-02,
         -4.0684e-01, -2.3106e-01],
        [-1.1002e-01, -5.8709e-02, -1.7477e-01, -1.1816e-01,  2.7913e-02,
         -2.3459e-01, -1.5437e-01],
        [-1.9524e-01,  5.4892e-04, -9.8265e-02, -2.0195e-01, -9.4162e-02,
         -8.6788e-02, -7.0251e-02],
        [-1.2910e-01, -4.1877e-01, -5.9799e-01, -1.9216e-01, -5.5129e-02,
         -2.1341e-01, -3.4433e-01],
        [-3.4389e-01, -1.5659e-01,  1.5129e-01, -6.0020e-01,  4.2968e-01,
         -1.3606e-01, -6.9029e-02],
        [-2.4623e-01, -9.8487e-02,  7.0824e-02, -1.5846e-01,  4.4013e-02,
         -4.3598e-01, -4.6302e-01],
        [-3.2047

Epoch 3:  67%|██████▋   | 4/6 [00:03<00:01,  1.29it/s, loss=1.898, v_num=47]torch.Size([32, 1, 48, 48])
tensor([[-1.4270e-01, -3.1790e-01, -2.0463e-01, -2.1245e-01, -3.6380e-01,
         -1.0734e-01, -4.1984e-01],
        [-6.2251e-01,  2.3956e-01,  1.6342e-01, -4.2744e-01, -1.8882e-01,
         -3.1542e-01, -5.1812e-01],
        [-1.2449e-01, -1.9345e-01, -3.0178e-01,  2.2648e-01, -1.8988e-01,
         -3.9411e-01, -4.9926e-01],
        [-2.3574e-01, -1.0489e-01, -1.5909e-01,  2.2228e-01, -9.4494e-02,
         -3.8224e-01, -4.7162e-01],
        [-1.0605e-01, -1.1681e-01, -4.3854e-01,  3.8870e-02, -2.7269e-01,
         -4.2073e-01, -2.8533e-01],
        [-3.8938e-01, -1.5999e-01, -3.1090e-02, -8.9504e-02, -2.2978e-01,
         -1.3540e-01, -2.6384e-01],
        [ 5.5042e-02, -3.6507e-01, -6.0192e-01, -2.8575e-02, -2.0819e-01,
         -5.6641e-01, -4.5993e-01],
        [-3.7934e-01, -1.9381e-01, -1.6952e-01, -1.7758e-01,  1.3152e-01,
          1.2802e-01, -1.3165e-01],
        [-6.8525

Epoch 4:  17%|█▋        | 1/6 [00:02<00:14,  2.87s/it, loss=1.870, v_num=47]torch.Size([32, 1, 48, 48])
tensor([[ 5.3976e-02, -4.2707e-01, -7.0094e-01, -4.2599e-02, -1.8888e-01,
         -6.5507e-01, -3.7227e-01],
        [-1.2219e-01, -1.7352e-01, -8.7092e-02, -1.6563e-01, -2.2437e-01,
          1.6672e-01, -6.3429e-01],
        [-3.5735e-01, -2.3253e-01, -3.9865e-01, -5.2862e-01, -5.9265e-02,
          4.5226e-02, -4.3416e-01],
        [ 6.6202e-02, -3.9454e-01, -3.7993e-01, -1.2131e-01, -4.6010e-01,
         -1.6707e-01, -5.8355e-01],
        [ 2.4762e-01, -6.0084e-01, -8.4495e-01,  1.3495e-01, -6.1502e-01,
         -4.8669e-01, -1.3901e-01],
        [-1.7338e-01, -5.4258e-01, -2.3622e-01, -2.3283e-01, -2.1341e-01,
          9.5584e-02, -5.5731e-01],
        [-2.3889e-01,  3.1586e-01,  1.6317e-01, -5.3818e-01, -3.7781e-01,
         -2.0861e-01, -4.7249e-01],
        [ 1.8036e-01, -6.8777e-01, -7.0870e-01,  4.0366e-01, -3.1360e-01,
         -5.3580e-01, -2.9287e-02],
        [-8.7629

Epoch 4:  67%|██████▋   | 4/6 [00:03<00:01,  1.25it/s, loss=1.837, v_num=47]torch.Size([32, 1, 48, 48])
tensor([[-0.0137, -0.5049, -0.5399,  0.0049, -0.1739,  0.0208, -0.7575],
        [-0.6224,  0.3803,  0.2292, -1.0865, -0.6327,  0.0934, -0.3468],
        [-0.2604, -1.0483, -0.4071, -0.6125,  0.1799,  0.4958,  0.1826],
        [-0.2007, -0.6430, -0.6195, -0.4979, -0.4579, -0.4228,  0.4136],
        [-0.0024, -0.3149, -0.2295,  0.1764, -0.3305, -0.3112, -0.5370],
        [-0.0328, -0.0565, -0.4970,  0.2582, -0.6183, -0.1970, -0.2729],
        [ 0.3354, -0.6386, -0.8931, -0.1970, -0.3123, -0.2048, -1.0634],
        [ 0.4752, -0.2277, -0.7575, -0.1460, -0.5473, -0.4437, -0.6638],
        [ 0.1972, -0.5723, -0.6877,  0.1188, -0.5723, -0.6108, -0.2741],
        [ 0.0905, -0.9181, -0.7371, -0.0534, -0.2582,  0.0798, -0.5208],
        [-0.3132,  0.4900, -0.0884, -0.1599, -0.7570, -0.3056, -0.5612],
        [ 0.3604, -0.3868, -0.4420, -0.0148, -0.3895, -0.1970, -0.5816],
        [-0.3477,  0

Epoch 5:  17%|█▋        | 1/6 [00:02<00:14,  2.97s/it, loss=1.794, v_num=47]torch.Size([32, 1, 48, 48])
tensor([[ 2.4763e-01, -7.5513e-01, -4.6202e-01,  9.2283e-02, -5.7973e-02,
         -1.8898e-01, -2.7102e-01],
        [-2.8181e-01, -1.1622e-01, -2.7704e-02,  2.5785e-01, -3.6560e-01,
         -2.9577e-01, -2.8432e-01],
        [ 4.0734e-01, -7.7198e-01, -9.6022e-01,  7.4033e-01, -7.0919e-01,
         -5.5492e-01, -9.1077e-02],
        [-3.1841e-01,  3.9397e-01, -4.2739e-02, -1.4418e-01, -5.0960e-01,
         -5.0805e-01, -2.1039e-01],
        [-4.9125e-01, -7.3075e-01, -7.4375e-01,  2.8872e-02, -1.0528e+00,
         -1.8333e-01,  3.8230e-01],
        [-5.5551e-01,  7.6569e-01,  2.1506e-01, -1.7425e-01, -8.1087e-01,
         -6.2463e-02, -1.4410e-01],
        [-4.1531e-01,  2.1086e-01, -1.0051e-01, -7.1241e-01, -1.5462e-01,
         -5.1583e-02, -6.0582e-01],
        [-2.5185e-01, -3.2476e-01, -2.5915e-01, -2.1360e-01, -9.0054e-02,
         -3.3389e-01,  2.2575e-02],
        [ 9.0276

Epoch 5:  67%|██████▋   | 4/6 [00:03<00:01,  1.20it/s, loss=1.766, v_num=47]torch.Size([32, 1, 48, 48])
tensor([[ 1.4875e-01, -8.6152e-01, -5.2593e-01, -3.4779e-01,  4.1319e-01,
         -2.2074e-01,  2.4633e-01],
        [-5.8956e-01,  4.1724e-01, -1.6410e-01, -1.3480e-01, -8.4639e-01,
         -4.3515e-01, -1.0750e-01],
        [-3.9666e-01, -1.0339e+00, -7.3504e-01, -5.2087e-01, -4.2644e-01,
         -3.8224e-01,  1.6596e-01],
        [-2.3378e-01,  6.3867e-01,  2.6516e-01, -1.3691e+00, -4.8629e-01,
         -1.4888e-01, -7.6095e-02],
        [-1.6826e-01,  3.6315e-01, -2.1846e-01, -1.8285e-01, -3.9796e-01,
         -1.2560e-01, -1.7621e-01],
        [-4.2340e-01, -6.8913e-01, -1.2139e+00, -6.6155e-01, -1.7356e-01,
         -4.5287e-01,  1.4422e-01],
        [-3.5260e-01,  5.0503e-01,  7.1348e-01, -1.3341e+00, -4.3156e-01,
         -2.5354e-01, -1.0395e+00],
        [-3.4902e-01, -3.6472e-01, -4.8274e-01, -4.0994e-01, -1.9294e-01,
         -7.7893e-02,  2.2110e-01],
        [-6.3097

Saving latest checkpoint..


Epoch 5: 100%|██████████| 6/6 [00:04<00:00,  1.34it/s, loss=1.749, v_num=47]


1

In [13]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir ../modules/lightning_logs/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6007 (pid 32821), started 0:26:56 ago. (Use '!kill 32821' to kill it.)