In [1]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as tf
from PIL import Image
import torch.autograd as autograd

The images files have the form "ID_2m_0P_xV_yH_z.jpg" where ID is the ID of the person, 2m is fixed, 0P means head pose of 0 degrees (only head pose used in this notebook)
x is the vertical orientation, y is the horizontal orientation and z is either L for left or R for right eye (note that the right eye patch was flipped horizontally).
In training the images are grouped as follows:
For a given person and a given eye (R or L) all orientations are grouped together. One element of the data set is of the form
imgs_r,angles_r,labels,imgs_t,angles_g where imgs_r is considered the "real" image with orientation angles_r, or x_r in the paper,
imgs_t with orientation angles_g is the image of the same person with different orientation (could be the same image since we go through a double loop) and the label is the ID of the person

In [2]:

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, dir_path, transform=None):
        super().__init__()
        self.transform = transform
        self.ids=50
        self.data_path = dir_path
        self.file_names = [f for f in os.listdir(self.data_path)
                      if f.endswith('.jpg')]
        self.file_dict = dict()
        for f_name in self.file_names:
            fields = f_name.split('.')[0].split('_')
            identity = fields[0]
            head_pose = fields[2]
            side = fields[-1]
            key = '_'.join([identity, head_pose, side])
            if key not in self.file_dict.keys():
                self.file_dict[key] = []
                self.file_dict[key].append(f_name)
            else:
                self.file_dict[key].append(f_name)
        self.train_images = []
        self.train_angles_r = []
        self.train_labels = []
        self.train_images_t = []
        self.train_angles_g = []

        self.test_images = []
        self.test_angles_r = []
        self.test_labels = []
        self.test_images_t = []
        self.test_angles_g = []
        self.preprocess()
    def preprocess(self):

        for key in self.file_dict.keys():

            if len(self.file_dict[key]) == 1:
                continue

            idx = int(key.split('_')[0])
            flip = 1
            if key.split('_')[-1] == 'R':
                flip = -1

            for f_r in self.file_dict[key]:

                file_path = os.path.join(self.data_path, f_r)

                h_angle_r = flip * float(
                    f_r.split('_')[-2].split('H')[0]) / 15.0
                    
                v_angle_r = float(
                    f_r.split('_')[-3].split('V')[0]) / 10.0
                    

                for f_g in self.file_dict[key]:

                    file_path_t = os.path.join(self.data_path, f_g)

                    h_angle_g = flip * float(
                        f_g.split('_')[-2].split('H')[0]) / 15.0
                        
                    v_angle_g = float(
                        f_g.split('_')[-3].split('V')[0]) / 10.0
                        

                    if idx <= self.ids:
                        self.train_images.append(file_path)
                        self.train_angles_r.append([h_angle_r, v_angle_r])
                        self.train_labels.append(idx - 1)
                        self.train_images_t.append(file_path_t)
                        self.train_angles_g.append([h_angle_g, v_angle_g])
                    else:
                        self.test_images.append(file_path)
                        self.test_angles_r.append([h_angle_r, v_angle_r])
                        self.test_labels.append(idx - 1)
                        self.test_images_t.append(file_path_t)
                        self.test_angles_g.append([h_angle_g, v_angle_g])

    def __getitem__(self, index):
        return (
            self.transform(Image.open(self.train_images[index])),
                torch.tensor(self.train_angles_r[index]),
                self.train_labels[index],
            self.transform(Image.open(self.train_images_t[index])),
                torch.tensor(self.train_angles_g[index]))
        
    def __len__(self):
        return len(self.train_images)
    

In [3]:
transform=tf.Compose([tf.ToTensor(),tf.Resize((64,64),antialias=True)])
# dataset=MyDataset(dir_path='/home/user/Downloads/dataset/0P',transform=transform)
dataset=MyDataset(dir_path='/home/user/Downloads/dataset/0P',transform=transform)

In [4]:
train_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

In [5]:
imgs_r,angles_r,labels,imgs_g,angles_g=next(iter(train_loader))
print(imgs_r.shape,angles_r.shape,labels.shape,imgs_g.shape,angles_g.shape)

torch.Size([32, 3, 64, 64]) torch.Size([32, 2]) torch.Size([32]) torch.Size([32, 3, 64, 64]) torch.Size([32, 2])


In [6]:
device='cuda' if torch.cuda.is_available() else 'cpu'
#device='cpu'

In [7]:
from transformer_net import Discriminator,Generator2
if os.path.isfile('discriminator.pth'):
    discriminator=torch.load('discriminator.pth')
    print('loaded discriminator')
else:
    discriminator=Discriminator()
    print('created discriminator')
if os.path.isfile('generator.pth'):
    generator=torch.load('generator.pth')
    print('loaded generator')
else:
    generator=Generator2()
    print('created generator')


generator=generator.to(device)
discriminator=discriminator.to(device)
LR = 5e-5
beta1=0.5
beta2=0.999
optimizer_g = torch.optim.Adam(generator.parameters(), LR,betas=(beta1, beta2))
optimizer_d = torch.optim.Adam(discriminator.parameters(), LR,betas=(beta1, beta2))

loaded discriminator
loaded generator


In [8]:
from loss_network import LossNetwork
loss_network=LossNetwork()
loss_network=loss_network.to(device)




In [9]:
from loss import content_style_loss,adv_loss_d,adv_loss_g,gaze_loss_d,gaze_loss_g,reconstruction_loss

In [10]:
def generator_step(generator,discriminator,loss_network,imgs_r,imgs_t,angles_r,angles_g):
    optimizer_g.zero_grad()
    generator.train()
    discriminator.eval()
    x_g=generator(imgs_r,angles_g)
    x_recon=generator(x_g,angles_r)
    loss_adv=adv_loss_g(discriminator,imgs_r,x_g)
    loss2=content_style_loss(loss_network,x_g,imgs_t)
    loss_p=loss2[0]+loss2[1]
    loss_gg=gaze_loss_g(discriminator,x_g,angles_g)
    loss_recon=reconstruction_loss(generator,imgs_r,x_recon)
    loss=loss_adv+100*loss_p+5*loss_gg+50*loss_recon
    loss.backward()
    optimizer_g.step()
    return loss.item()

In [11]:
def discriminator_step(generator,discriminator,imgs_r,imgs_t,angles_r,angles_g):
    optimizer_d.zero_grad()
    generator.eval()
    discriminator.train()
    x_g=generator(imgs_r,angles_g)
    loss1=adv_loss_d(discriminator,imgs_r,x_g)
    loss2=gaze_loss_d(discriminator,imgs_r,angles_r)
    loss=loss1+5*loss2
    loss.backward()
    optimizer_d.step()
    return loss.item()

In [12]:
from PIL import Image
import numpy as np
def recover_image(img):
    img=img.cpu().numpy().transpose(0, 2, 3, 1)*255
    return img.astype(np.uint8)
def save_images(imgs, filename):
    height=recover_image(imgs[0])[0].shape[0]
    width=recover_image(imgs[0])[0].shape[1]
    total_width=width*len(imgs)
    
    new_im = Image.new('RGB', (total_width+len(imgs), height))
    for i,img in enumerate(imgs):
        result = Image.fromarray(recover_image(img)[0])
        new_im.paste(result, (i*width+i,0))
    new_im.save(filename)

In [13]:
!mkdir -p debug

/bin/bash: /home/user/anaconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)


In [14]:

epochs=300
for epoch in range(epochs):
    count=0
  
    for imgs_r, angles_r, labels, imgs_t, angles_g in train_loader:
        count+=1
        imgs_r=imgs_r.to(device)
        imgs_t=imgs_t.to(device)
        angles_r=angles_r.to(device)
        angles_g=angles_g.to(device)
        l_d=discriminator_step(generator,discriminator,imgs_r,imgs_t,angles_r,angles_g)
        if count%5==0:
            l_g=generator_step(generator,discriminator,loss_network,imgs_r,imgs_t,angles_r,angles_g)
        if count%1000==0:
            imgs=[imgs_r]
            for h in [-15,-10,-5,0,5,10,15]:
                    a=torch.tile(torch.tensor([h/15.,0.]),[32,1])
                    a=a.to(device)
                    y=generator(imgs_r,a)
                    imgs.append(y.detach())
            save_images(imgs, "./debug/{}_{}.png".format(epoch,count))
    print(l_d,l_g)
    if epoch%20==0:
        torch.save(generator, './generator.pth')
        torch.save(discriminator, './discriminator.pth')
     
    

-0.273295134305954 1.1560956239700317
-0.320864737033844 1.15312659740448
-0.19831483066082 1.2898294925689697
-0.3670003414154053 1.0385856628417969
-0.24964606761932373 1.1227964162826538
-0.23001514375209808 1.0498641729354858
-0.2774949073791504 1.0742840766906738
-0.20838862657546997 1.0117347240447998
-0.2539735734462738 1.2010776996612549
-0.15351393818855286 1.0695288181304932
-0.21429996192455292 0.9391365051269531
-0.16183923184871674 0.948602557182312
-0.23498757183551788 0.9788606762886047
-0.19226963818073273 0.7902793884277344
-0.28566884994506836 0.8925408124923706
-0.229110985994339 1.0464166402816772
-0.2364182472229004 0.8952992558479309
-0.2216944843530655 0.9106601476669312
-0.16782627999782562 0.871482789516449
-0.25914621353149414 0.8641177415847778
-0.15823791921138763 0.8067793250083923
-0.13979174196720123 0.9787031412124634
-0.27305740118026733 0.8142602443695068
-0.24601902067661285 0.838952898979187
-0.10145123302936554 0.825117290019989
-0.13127751648426056

In [15]:
a=torch.tile(torch.tensor([0.,0.]),[32,1])
print(a.size(),angles_r.size())
y=generator(imgs_r.to(device),angles_r.to(device))

torch.Size([32, 2]) torch.Size([4, 2])
