In [1]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as tf
from PIL import Image
import torch.autograd as autograd
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, dir_path, transform=None):
        super().__init__()
        self.transform = transform
        self.ids=50
        self.data_path = dir_path
        self.file_names = [f for f in os.listdir(self.data_path)
                      if f.endswith('.jpg')]
        self.file_dict = dict()
        for f_name in self.file_names:
            fields = f_name.split('.')[0].split('_')
            identity = fields[0]
            head_pose = fields[2]
            side = fields[-1]
            key = '_'.join([identity, head_pose, side])
            if key not in self.file_dict.keys():
                self.file_dict[key] = []
                self.file_dict[key].append(f_name)
            else:
                self.file_dict[key].append(f_name)
        self.train_images = []
        self.train_angles_r = []
        self.train_labels = []
        self.train_images_t = []
        self.train_angles_g = []

        self.test_images = []
        self.test_angles_r = []
        self.test_labels = []
        self.test_images_t = []
        self.test_angles_g = []
        self.preprocess()
    def preprocess(self):

        for key in self.file_dict.keys():

            if len(self.file_dict[key]) == 1:
                continue

            idx = int(key.split('_')[0])
            flip = 1
            if key.split('_')[-1] == 'R':
                flip = -1

            for f_r in self.file_dict[key]:

                file_path = os.path.join(self.data_path, f_r)

                h_angle_r = flip * float(
                    f_r.split('_')[-2].split('H')[0]) / 15.0
                    
                v_angle_r = float(
                    f_r.split('_')[-3].split('V')[0]) / 10.0
                    

                for f_g in self.file_dict[key]:

                    file_path_t = os.path.join(self.data_path, f_g)

                    h_angle_g = flip * float(
                        f_g.split('_')[-2].split('H')[0]) / 15.0
                        
                    v_angle_g = float(
                        f_g.split('_')[-3].split('V')[0]) / 10.0
                        

                    if idx <= self.ids:
                        self.train_images.append(file_path)
                        self.train_angles_r.append([h_angle_r, v_angle_r])
                        self.train_labels.append(idx - 1)
                        self.train_images_t.append(file_path_t)
                        self.train_angles_g.append([h_angle_g, v_angle_g])
                    else:
                        self.test_images.append(file_path)
                        self.test_angles_r.append([h_angle_r, v_angle_r])
                        self.test_labels.append(idx - 1)
                        self.test_images_t.append(file_path_t)
                        self.test_angles_g.append([h_angle_g, v_angle_g])

    def __getitem__(self, index):
        return (
            self.transform(Image.open(self.train_images[index])),
                torch.tensor(self.train_angles_r[index]),
                self.train_labels[index],
            self.transform(Image.open(self.train_images_t[index])),
                torch.tensor(self.train_angles_g[index]))
        
    def __len__(self):
        return len(self.train_images)
    

In [2]:
transform=tf.Compose([tf.ToTensor(),tf.Resize((64,64),antialias=True)])
# dataset=MyDataset(dir_path='/home/user/Downloads/dataset/0P',transform=transform)
dataset=MyDataset(dir_path='c:\\Users\\hikma\\Downloads/dataset/0P',transform=transform)

In [3]:
train_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

In [4]:
imgs_r,angles_r,labels,imgs_g,angles_g=next(iter(train_loader))
print(imgs_r.shape,angles_r.shape,labels.shape,imgs_g.shape,angles_g.shape)

torch.Size([32, 3, 64, 64]) torch.Size([32, 2]) torch.Size([32]) torch.Size([32, 3, 64, 64]) torch.Size([32, 2])


In [5]:
device='cuda' if torch.cuda.is_available() else 'cpu'
#device='cpu'

In [6]:
from transformer_net import Generator,Discriminator,Generator2
from net import NetD
generator=Generator2()
discriminator=Discriminator()
generator=generator.to(device)
discriminator=discriminator.to(device)
LR = 5e-5
beta1=0.5
beta2=0.999
optimizer_g = torch.optim.Adam(generator.parameters(), LR,betas=(beta1, beta2))
optimizer_d = torch.optim.Adam(discriminator.parameters(), LR,betas=(beta1, beta2))

In [7]:
from loss_network import LossNetwork
loss_network=LossNetwork()
loss_network=loss_network.to(device)




In [8]:
from loss import content_style_loss,adv_loss_d,adv_loss_g,gaze_loss_d,gaze_loss_g,reconstruction_loss

In [9]:
def generator_step(generator,discriminator,loss_network,imgs_r,imgs_t,angles_r,angles_g):
    optimizer_g.zero_grad()
    generator.train()
    discriminator.eval()
    x_g=generator(imgs_r,angles_g)
    x_recon=generator(x_g,angles_r)
    loss_adv=adv_loss_g(discriminator,imgs_r,x_g)
    loss2=content_style_loss(loss_network,x_g,imgs_t)
    loss_p=loss2[0]+loss2[1]
    loss_gg=gaze_loss_g(discriminator,x_g,angles_g)
    loss_recon=reconstruction_loss(generator,imgs_r,x_recon)
    loss=loss_adv+100*loss_p+5*loss_gg+50*loss_recon
    loss.backward()
    optimizer_g.step()
    return loss.item()

In [10]:
def discriminator_step(generator,discriminator,imgs_r,imgs_t,angles_r,angles_g):
    optimizer_d.zero_grad()
    generator.eval()
    discriminator.train()
    x_g=generator(imgs_r,angles_g)
    loss1=adv_loss_d(discriminator,imgs_r,x_g)
    loss2=gaze_loss_d(discriminator,imgs_r,angles_r)
    loss=loss1+5*loss2
    loss.backward()
    optimizer_d.step()
    return loss.item()

In [11]:
from PIL import Image
import numpy as np
def recover_image(img):
    img=img.cpu().numpy().transpose(0, 2, 3, 1)*255
    return img.astype(np.uint8)
def save_debug_image(tensor_orig, tensor_transformed, filename):
    assert tensor_orig.size() == tensor_transformed.size()
    result = Image.fromarray(recover_image(tensor_transformed)[0])
    orig = Image.fromarray(recover_image(tensor_orig)[0])
    new_im = Image.new('RGB', (result.size[0] * 2 + 5, result.size[1]))
    new_im.paste(orig, (0,0))
    new_im.paste(result, (result.size[0] + 5,0))
    new_im.save(filename)

In [12]:
!mkdir -p debug

A subdirectory or file -p already exists.
Error occurred while processing: -p.
A subdirectory or file debug already exists.
Error occurred while processing: debug.


In [13]:

epochs=300
for epoch in range(epochs):
    count=0
    a=torch.tile(torch.tensor([0.,0.]),[32,1])
    b=torch.tile(torch.tensor([-1.,-1.]),[32,1])
    c=torch.tile(torch.tensor([1.,1.]),[32,1])
    #print(a.shape)
    #y=generator(imgs_r.to(device),a.to(device))
    for imgs_r, angles_r, labels, imgs_t, angles_g in train_loader:
        count+=1
        imgs_r=imgs_r.to(device)
        imgs_t=imgs_t.to(device)
        angles_r=angles_r.to(device)
        angles_g=angles_g.to(device)
        l_d=discriminator_step(generator,discriminator,imgs_r,imgs_t,angles_r,angles_g)
        if count%5==0:
            l_g=generator_step(generator,discriminator,loss_network,imgs_r,imgs_t,angles_r,angles_g)
        if count%1000==0:
            #a=torch.tile(torch.tensor([0.,0.]),[32,1])
            ya=generator(imgs_r,a.to(device))
            yb=generator(imgs_r,b.to(device))
            yc=generator(imgs_r,c.to(device))
            save_debug_image(imgs_r, ya.detach(), "./debug/{}_{}_a.png".format(epoch,count))
            save_debug_image(imgs_r, yb.detach(), "./debug/{}_{}_b.png".format(epoch,count))
            save_debug_image(imgs_r, yc.detach(), "./debug/{}_{}_c.png".format(epoch,count))
    print(l_d,l_g)

     
    

-0.14090707898139954 3.1814870834350586
-0.3075118660926819 2.5949172973632812
-0.31422021985054016 2.496328353881836
-0.28986045718193054 1.9234073162078857
-0.20188328623771667 1.5243966579437256
-0.1502615213394165 1.7172067165374756
-0.24869534373283386 1.580897331237793
-0.21149319410324097 1.5009021759033203
-0.34318530559539795 1.5165023803710938
-0.15337936580181122 1.3503401279449463
-0.21691101789474487 1.188339114189148


KeyboardInterrupt: 

In [None]:
a=torch.tile(torch.tensor([0.,0.]),[32,1])
print(a.size(),angles_r.size())
y=generator(imgs_r.to(device),angles_r.to(device))