In [1]:
import os
import cv2
import math
import torch
import numpy as np
from PIL import Image
import torch.nn as nn
from utils import Conv2d
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [2]:
# parameters setting
input_shape = [3,96,96]
num_classes = 7
num_layers = 5
init_lr = 0.0003
device = torch.device("cuda")
batch_size = 16
use_fer=False
use_illumi_correct = False

In [3]:
def illumination_correction(img):
    imhist,bins = np.histogram(img.flatten(),256,normed = True)
    cdf = imhist.cumsum()
    cdf = max(imhist) * cdf / cdf[-1]
    im2 = np.interp(img.flatten(),bins[:-1],cdf)
    im2 = im2.reshape(img.shape)
    return torch.Tensor(im2/255.)

In [4]:
# load data
data_dir = ''

if use_fer:
    full_data = torch.load(data_dir + 'pickles/fer2013.pt')
    full_img, full_emo = full_data['images'][:,None,:,:], full_data['label']

    train_img,train_y = full_img[:28709],full_emo[:28709]
    test_img,test_y = full_img[28709:28709+3589],full_emo[28709:28709+3589]

    train_img = torch.cat((train_img,train_img,train_img),1)
    test_img = torch.cat((test_img,test_img,test_img),1)

else:
#     train_data = torch.load(data_dir + 'pickles/AffectNet_train.pt')
#     train_img, train_y = train_data['images'], train_data['emotion']
    
    train_data = torch.load(data_dir + 'pickles/AffectNet_train_full_p1.pt')
    train_img, train_y = train_data['images'], train_data['labels']

    test_data = torch.load(data_dir + 'pickles/AffectNet_test.pt')
    test_img, test_y = test_data['images'], test_data['emotion']
    
if use_illumi_correct:
    train_img = illumination_correction(train_img)
    test_img = illumination_correction(test_img)

In [5]:
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize([input_shape[1],input_shape[2]]), 
    transforms.RandomCrop([input_shape[1],input_shape[2]]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()#,
#     normalize
])
test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize([input_shape[1],input_shape[2]]),
    transforms.CenterCrop([input_shape[1],input_shape[2]]),
    transforms.ToTensor()#,
#     normalize
])

def default_loader(images):
    img_tensor = train_transform(images)
    return img_tensor
def default_test_loader(images):
    img_tensor = test_transform(images)
    return img_tensor

class Dataset(Dataset):
    def __init__(self, imgs,ys,loader=default_loader):
        self.images = imgs 
        self.label = ys
        self.loader = loader

    def __getitem__(self, index):

        fn = self.images[index]
        img = self.loader(fn)
        label = self.label[index]
        return img,label

    def __len__(self):
        return len(self.images)

In [15]:
from sklearn.metrics import classification_report
import time

class Train_AutoEncoder(nn.Module):
    def __init__(self, num_layers, target_shape, init_lr, adam=True):
        super(Train_AutoEncoder, self).__init__()
        print('AutoEncoder_{}layers'.format(num_layers+1))
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.im_shape = target_shape
        self.dec_im = None
        self.imgs_train = None
        
        self.encoder = nn.Sequential(
            # conv_0
            Conv2d(in_channels=target_shape[0], out_channels=32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            # conv_1
            Conv2d(in_channels=32, out_channels=64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            # conv_2
            Conv2d(in_channels=64, out_channels=128, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            # conv_3
            Conv2d(in_channels=128, out_channels=256, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            # conv_4
            Conv2d(in_channels=256, out_channels=512, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(512)
        )
        
        if num_layers==5:
            self.encoder_ = nn.Sequential(
                nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=2, stride=2),
                nn.ReLU(),
                nn.BatchNorm2d(1024)
            )
            
            self.decoder_ = nn.Sequential(
                F.interpolate(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2),
                nn.ReLU(),
                nn.BatchNorm2d(512)
            )
        
        self.decoder = nn.Sequential(
            # D2            
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            # D3
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            # D4
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            # D5
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),

            Conv2d(in_channels=32, out_channels=target_shape[0],kernel_size=3, stride=1),
            nn.Tanh(),
        )
            
            
        self.optimizer = torch.optim.Adam(params=[p for p in self.parameters() if p.requires_grad], lr=init_lr, eps=1e-5)
        self.loss_fn = nn.MSELoss()
    
        if adam:
            self.optimizer = torch.optim.Adam(params=[p for p in self.parameters() if p.requires_grad], lr=init_lr, eps=1e-5)
        else:
            self.optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
        self.loss_fn = nn.MSELoss() 
    
    def forward(self, net, x):
        batch_size = x.size(0)
        for i in range(len(net)):
            x = net[i](x)
        return x
        
    def train_model(self, data_loader, num_epochs, init_lr, device):
        self.train()
        start=time.time()
        
        total_loss = 0
        
        for epoch in range(num_epochs):
            running_loss = 0.
            num_batches = 0
            total_enc_im = []
            total_dec_im = []
            for count, (images, label) in enumerate(data_loader):
                self.optimizer.zero_grad()
                
                images = images.to(device)
                label = label.to(device)
                
                enc_im = self.forward(self.encoder, images)
                dec_im = self.forward(self.decoder, enc_im)  
                
                loss = self.loss_fn(dec_im, images)
                loss.backward()
                
                self.optimizer.step()
                
                total_enc_im.append(enc_im.detach().cpu())
                total_dec_im.append(dec_im.detach().cpu())
                
                running_loss += loss.detach().cpu().item()
                num_batches += 1
            total_loss = running_loss/num_batches
            elapsed = time.time()-start
            print('\t', 'epoch=',epoch+1, '\t time = {:.0f}m {:.0f}s'.format(elapsed // 60, elapsed % 60), 
                  '\t lr=', init_lr, '\t loss = ', total_loss )
            
            total_dec_im = torch.cat(total_dec_im,0)
            total_enc_im = torch.cat(total_enc_im,0)
#         return total_enc_im, total_dec_im
    
    def test_model(self, data_loader, device):
        self.eval()
        start=time.time()
        
        preds = []
        truth = []
        total_enc_im = []
        total_dec_im = []
        for count, (images, label) in enumerate(data_loader):
            self.optimizer.zero_grad()

            images = images.to(device)

            enc_im = self.forward(self.encoder, images)
            dec_im = self.forward(self.decoder, enc_im)
            
            total_enc_im.append(enc_im.detach().cpu())
            total_dec_im.append(dec_im.detach().cpu())

        total_dec_im = torch.cat(total_dec_im,0)
        total_enc_im = torch.cat(total_enc_im,0)
        return total_enc_im, total_dec_im

In [16]:
train_input  = Dataset(train_img,train_y)
train_loader = DataLoader(train_input, batch_size=batch_size,shuffle=True)

test_input  = Dataset(test_img,test_y,loader=default_test_loader)
test_loader = DataLoader(test_input, batch_size=batch_size,shuffle=False)

model = Train_AutoEncoder(num_layers=4, target_shape=input_shape, init_lr=init_lr)
model.to(device)
model.train_model(data_loader=test_loader, num_epochs=100, init_lr=init_lr, device=device)

AutoEncoder_5layers
	 epoch= 1 	 time = 0m 40s 	 lr= 0.0003 	 loss =  0.08164751309833435
	 epoch= 2 	 time = 1m 23s 	 lr= 0.0003 	 loss =  0.012193172400366497
	 epoch= 3 	 time = 2m 2s 	 lr= 0.0003 	 loss =  0.009853484736304713
	 epoch= 4 	 time = 2m 41s 	 lr= 0.0003 	 loss =  0.00875028188991016
	 epoch= 5 	 time = 3m 21s 	 lr= 0.0003 	 loss =  0.00793762051513138
	 epoch= 6 	 time = 4m 2s 	 lr= 0.0003 	 loss =  0.007318517488754912
	 epoch= 7 	 time = 4m 43s 	 lr= 0.0003 	 loss =  0.006900601780050557
	 epoch= 8 	 time = 5m 23s 	 lr= 0.0003 	 loss =  0.006481693272976434
	 epoch= 9 	 time = 6m 3s 	 lr= 0.0003 	 loss =  0.006181856900130369
	 epoch= 10 	 time = 6m 41s 	 lr= 0.0003 	 loss =  0.005905172094430553
	 epoch= 11 	 time = 7m 21s 	 lr= 0.0003 	 loss =  0.005836370054496342
	 epoch= 12 	 time = 7m 59s 	 lr= 0.0003 	 loss =  0.00550121390558439
	 epoch= 13 	 time = 8m 38s 	 lr= 0.0003 	 loss =  0.005142930071146044
	 epoch= 14 	 time = 9m 19s 	 lr= 0.0003 	 loss =  0.0049128

In [17]:
torch.save(model.state_dict(),data_dir+'BaselinesGray/AutoEncoder_AffectNet_inputsize{}_epo100.pth'.format(96))