In [1]:
import os
import io
from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

In [3]:
class t2iDS(Dataset):
    """Text-to-image dataset read from an HDF5 file.

    Each item pairs an image with its own text embedding ("right"), an image
    taken from a sample of a different class ("wrong") and the embedding of a
    randomly chosen sample ("random") -- presumably for a matching-aware
    (GAN-CLS style) discriminator; confirm against the training loop.

    Layout read below: ``file[split][key]`` holds ``'img'`` (encoded image
    bytes, decoded with PIL), ``'embeddings'``, ``'class'`` and ``'txt'``.

    NOTE(review): the h5py handle is opened eagerly in ``__init__`` and then
    shared by every access; h5py handles are not generally safe to share
    across forked DataLoader workers, so with ``num_workers > 0`` consider
    opening the file lazily inside each worker.
    """

    # Side length (pixels) every decoded image is resized to.
    imsize = 64

    def __init__(self, filename, split):
        self.filename = filename
        self.split = split
        self.dataset = h5py.File(self.filename, 'r')
        # Materialize the key order once so integer indices map to stable samples.
        self.keys = [str(k) for k in self.dataset[self.split].keys()]

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict.

        ``'right_im'`` / ``'wrong_im'`` are float32 tensors of shape
        (3, imsize, imsize) scaled from [0, 255] to [-1, 1];
        ``'right_embed'`` / ``'random_embed'`` are float32 tensors;
        ``'txt'`` is the caption as a Python string.
        """
        data = self.dataset[self.split][self.keys[idx]]

        right_bytes = bytes(np.array(data['img']))
        right_embed = np.array(data['embeddings'], dtype=float)
        wrong_bytes = bytes(np.array(self.get_wrong_im(data['class'])))
        random_embed = np.array(self.get_random_embed())

        right_im = self.validate_im(self._decode_im(right_bytes))
        wrong_im = self.validate_im(self._decode_im(wrong_bytes))

        txt = np.array(data['txt']).astype(str)

        return {
            'right_im': self._normalize_pixels(right_im),
            'wrong_im': self._normalize_pixels(wrong_im),
            'right_embed': torch.tensor(right_embed, dtype=torch.float32),
            'random_embed': torch.tensor(random_embed, dtype=torch.float32),
            'txt': str(txt),
        }

    def _decode_im(self, raw):
        """Decode encoded image bytes into an RGB PIL image of imsize x imsize."""
        # Convert before resizing: palette ('P') images would otherwise turn
        # into an index map, and RGBA/CMYK would carry the wrong channel count.
        # PIL's resize() takes a single (width, height) tuple.
        image = Image.open(io.BytesIO(raw)).convert('RGB')
        return image.resize((self.imsize, self.imsize))

    @staticmethod
    def _normalize_pixels(pixels):
        """Map pixel values from [0, 255] to [-1, 1] as a (copied) float32 tensor."""
        return torch.tensor(pixels, dtype=torch.float32).sub_(127.5).div_(127.5)

    def _random_sample(self):
        """Return the entry of a uniformly random key in this split."""
        key = self.keys[np.random.randint(len(self.keys))]
        return self.dataset[self.split][key]

    def get_wrong_im(self, clas, max_tries=1000):
        """Return the ``'img'`` entry of a random sample whose class differs from ``clas``.

        Args:
            clas: class value of the reference sample (an h5py dataset or
                anything ``np.asarray`` accepts).
            max_tries: maximum number of random draws before giving up.

        Raises:
            RuntimeError: if no differently-classed sample is drawn within
                ``max_tries`` attempts (e.g. the split contains one class only).
        """
        target = np.asarray(clas)
        for _ in range(max_tries):
            data = self._random_sample()
            # Compare stored *values*: h5py Dataset objects compare equal only
            # when they refer to the same HDF5 object, not by content.
            if not np.array_equal(np.asarray(data['class']), target):
                return data['img']
        raise RuntimeError(
            "no sample with a different class found after %d draws" % max_tries
        )

    def get_random_embed(self):
        """Return the ``'embeddings'`` entry of a uniformly random sample."""
        return self._random_sample()['embeddings']

    def validate_im(self, im):
        """Return ``im`` as a float32 channel-first (3, H, W) array.

        A 2-D (grayscale) input is replicated across three channels; any
        channels beyond the third (e.g. the alpha of RGBA) are dropped.
        """
        arr = np.asarray(im, dtype=np.float32)
        if arr.ndim < 3:
            arr = np.stack([arr, arr, arr], axis=-1)
        elif arr.shape[2] > 3:
            arr = arr[:, :, :3]
        return arr.transpose(2, 0, 1)

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping a noise tensor to a 64x64 image.

    NOTE(review): as written the generator consumes noise only; no text
    embedding enters its input, so it is not text-conditioned even though the
    dataset yields embeddings -- confirm this is intended.
    """

    def __init__(self, noise_dim=100, ngf=64, channels=3):
        """
        Args:
            noise_dim: number of input noise channels.
            ngf: base feature-map width; the layers use 8x, 4x, 2x, 1x of it.
            channels: number of output image channels.
        """
        super(Generator, self).__init__()
        # The five transposed convolutions upsample a 1x1 input
        # 1 -> 4 -> 8 -> 16 -> 32 -> 64, so the output side is fixed at 64.
        self.imsize = 64
        self.channels = channels
        self.noise_dim = noise_dim
        self.ngf = ngf
        self.G = nn.Sequential(
            nn.ConvTranspose2d(self.noise_dim, self.ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(self.ngf, self.channels, 4, 2, 1, bias=False),
            # Tanh keeps outputs in [-1, 1].
            nn.Tanh()
        )

    def forward(self, x):
        """Generate images from noise.

        Args:
            x: noise of shape (B, noise_dim, 1, 1), or flat (B, noise_dim),
                which is reshaped to (B, noise_dim, 1, 1).

        Returns:
            Tensor of shape (B, channels, 64, 64) with values in [-1, 1].
        """
        if x.dim() == 2:
            x = x[:, :, None, None]
        return self.G(x)

class Discriminator(nn.Module):
    def __init__(self):