In [None]:
import torch
import math
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import datasets, transforms
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms.functional as TF

In [None]:
class PokeDataset(Dataset):
    """In-memory dataset of (before, after) image pairs with one-hot encoded actions.

    Each run directory under ``root`` must contain ``actions.npy`` and frames
    named ``img_%04d.jpg``. For every action row ``i`` whose column 4 equals 1,
    frames ``i`` and ``i + 1`` are loaded with ``TF.to_tensor``. Columns 0-3 of
    the row are then encoded:
    - columns 0 and 1 go through ``loc`` (px, py);
    - column 2 goes through ``angle``;
    - column 3 goes through ``length``.

    Items are ``([img_before, img_after], [px, py, theta, l])``. The encodings are
    float64 numpy vectors of length 20, 20, 36 and 11.

    Args:
        root: directory holding one sub-directory per run.
        num_runs: how many runs (in sorted order) to load. The default of 1 matches
            the original notebook; ``None`` loads every run.

    Raises:
        FileNotFoundError: if ``root`` contains no run directories.
    """

    N_LOC_BINS = 20
    N_ANGLE_BINS = 36
    N_LENGTH_BINS = 11

    def __init__(self, root="poke/train", num_runs=1):
        runs = sorted(glob.glob(root + "/*"))
        if not runs:
            raise FileNotFoundError("no run directories found under %r" % root)
        self.images = []
        self.actions = []

        for run in tqdm(runs[:num_runs]):
            actions = np.load(run + "/actions.npy")
            # NOTE(review): the location range is recomputed for every run, so with
            # num_runs > 1 location bins are only comparable within a run - confirm.
            self.px_max = actions[:, 0].max()
            self.px_min = actions[:, 0].min()
            self.py_max = actions[:, 1].max()
            self.py_min = actions[:, 1].min()
            for i, row in enumerate(actions):
                if row[4] != 1:
                    continue
                self.images.append([self._load_frame(run, i), self._load_frame(run, i + 1)])
                self.actions.append([
                    self.loc(row[0], self.px_max, self.px_min),
                    self.loc(row[1], self.py_max, self.py_min),
                    self.angle(row[2]),
                    self.length(row[3]),
                ])

    def __len__(self):
        return len(self.actions)

    def __getitem__(self, index):
        return (self.images[index], self.actions[index])

    @staticmethod
    def _load_frame(run, index):
        """Read ``img_<index>.jpg`` from ``run`` as a tensor; the file is always closed."""
        with Image.open(run + "/img_%04d.jpg" % index) as img:
            return TF.to_tensor(img)

    @staticmethod
    def _one_hot(index, size):
        """Float64 one-hot vector of length ``size``; ``index`` is clamped into [0, size - 1].

        Clamping prevents numpy from silently wrapping a negative index to the
        last bin, and prevents an IndexError past the end.
        """
        x = np.zeros(size)
        x[min(max(int(index), 0), size - 1)] = 1
        return x

    def loc(self, p, maxVal, minVal):
        """One-hot encode ``p`` into N_LOC_BINS equal-width bins over [minVal, maxVal]."""
        if p >= maxVal:
            # The upper edge belongs to the last bin; the formula below would give bin 20.
            return self._one_hot(self.N_LOC_BINS - 1, self.N_LOC_BINS)
        if maxVal <= minVal:
            # Degenerate range (nothing to divide) and p lies below it.
            return self._one_hot(0, self.N_LOC_BINS)
        width = (maxVal - minVal) / self.N_LOC_BINS
        return self._one_hot((p - minVal) // width, self.N_LOC_BINS)

    def angle(self, theta):
        """One-hot encode ``theta`` into N_ANGLE_BINS bins of width pi/13.

        NOTE(review): pi/13-wide bins only reach bin 25 for theta in [0, 2*pi).
        Evenly spaced 36 bins over a full turn would be pi/18 wide - confirm
        against the data.
        """
        return self._one_hot(theta // (math.pi / 13), self.N_ANGLE_BINS)

    def length(self, l):
        """One-hot encode ``l`` into N_LENGTH_BINS bins of width 12 starting at 0.01.

        NOTE(review): the units of ``l`` are not visible here; a bin width of 12
        with an offset of 0.01 looks inconsistent - confirm.
        """
        return self._one_hot((l - .01) // 12, self.N_LENGTH_BINS)

In [None]:
# Load the poke dataset and wrap it in a shuffling DataLoader.
# The seeded generator makes the shuffle order reproducible across Restart & Run All.
SEED = 0
BATCH_SIZE = 2
data = PokeDataset()
dataLoader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True,
                        generator=torch.Generator().manual_seed(SEED))

In [None]:
class Inverse(nn.Module):
    """Inverse model: infer the poke action that turns image ``I_0`` into ``I_1``.

    Both images pass through one shared AlexNet-style conv trunk. The two
    flattened features are concatenated and decoded by four MLP heads in a chain:
    length ``l`` -> angle ``theta`` -> location ``px`` -> location ``py``. Every
    head after the first also receives the previous head's output. Each head ends
    in a softmax, so it returns probabilities rather than logits.

    NOTE(review): the head widths assume each image flattens to 9216 features
    (18432 for the pair). With this trunk that holds for inputs reducing to
    256x6x6, e.g. 227x227 or 240x240 - confirm against the dataset.
    """

    def __init__(self):
        super(Inverse, self).__init__()
        # Shared conv trunk (AlexNet layout): Conv2d(in, out, kernel, stride[, padding]).
        self.conv1 = nn.Conv2d(3, 96, 11, 4)
        self.maxPool1 = nn.MaxPool2d(3, 2)
        self.conv2 = nn.Conv2d(96, 256, 5, 1, 2)
        self.maxPool2 = nn.MaxPool2d(3, 2)
        self.conv3 = nn.Conv2d(256, 384, 3, 1, 1)
        self.conv4 = nn.Conv2d(384, 384, 3, 1, 1)
        self.conv5 = nn.Conv2d(384, 256, 3, 1, 1)
        self.maxPool3 = nn.MaxPool2d(3, 2)

        # Heads are registered in the order l, theta, px, py, with layers 1..4
        # each, so parameter init order and state_dict keys (l1, theta3, ...) stay
        # fixed. Input width = pair latent + width of the previous head's output.
        pair_dim = 18432
        head_specs = (
            ("l", pair_dim, 11),
            ("theta", pair_dim + 11, 36),
            ("px", pair_dim + 36, 20),
            ("py", pair_dim + 20, 20),
        )
        for prefix, in_dim, out_dim in head_specs:
            setattr(self, prefix + "1", nn.Linear(in_dim, 9216))
            setattr(self, prefix + "2", nn.Linear(9216, 2304))
            setattr(self, prefix + "3", nn.Linear(2304, 576))
            setattr(self, prefix + "4", nn.Linear(576, out_dim))

        self.softMax = nn.Softmax(dim=1)

    def _mlp(self, x, layers):
        """Apply ``layers`` with a ReLU after all but the last, then the shared softmax."""
        *hidden, final = layers
        for layer in hidden:
            x = F.relu(layer(x))
        return self.softMax(final(x))

    def latentRepresentation(self, x):
        """Run the conv trunk on a batch of images and flatten to (batch, features)."""
        for conv, pool in ((self.conv1, self.maxPool1), (self.conv2, self.maxPool2)):
            x = pool(F.relu(conv(x)))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.maxPool3(F.relu(self.conv5(x)))
        return torch.flatten(x, 1)

    def l(self, x):
        """Length head: probabilities over the l4 output bins."""
        return self._mlp(x, (self.l1, self.l2, self.l3, self.l4))

    def theta(self, x):
        """Angle head: probabilities over the theta4 output bins."""
        return self._mlp(x, (self.theta1, self.theta2, self.theta3, self.theta4))

    def px(self, x):
        """x-location head: probabilities over the px4 output bins."""
        return self._mlp(x, (self.px1, self.px2, self.px3, self.px4))

    def py(self, x):
        """y-location head: probabilities over the py4 output bins."""
        return self._mlp(x, (self.py1, self.py2, self.py3, self.py4))

    def forward(self, I_0, I_1):
        """Return ``(x_0, x_1, px, py, theta, l)`` for image batches ``I_0`` and ``I_1``.

        ``x_0``/``x_1`` are the flattened trunk features. The four action outputs
        are softmax probabilities, each conditioned on the previous one.
        """
        x_0 = self.latentRepresentation(I_0)
        x_1 = self.latentRepresentation(I_1)
        pair = torch.cat((x_0, x_1), 1)
        l = self.l(pair)
        theta = self.theta(torch.cat((pair, l), 1))
        px = self.px(torch.cat((pair, theta), 1))
        py = self.py(torch.cat((pair, px), 1))
        return (x_0, x_1, px, py, theta, l)
    
class Forward(nn.Module):
    """Forward model: predict the next latent from the current latent and an action.

    ``forward`` concatenates the action encodings and ``x_0`` in the order
    ``(l, theta, px, py, x_0)``. The result goes through a 3-layer MLP whose
    output width equals ``latent_dim``, so it can be compared directly with the
    true next latent.

    Args:
        latent_dim: width of ``x_0`` and of the prediction. The default 9216 is
            half of ``Inverse``'s 18432-wide concatenated pair latent.
        action_dim: total width of the concatenated action encodings
            (11 + 36 + 20 + 20 = 87 by default).
        hidden_dims: widths of the two hidden layers.
    """

    def __init__(self, latent_dim=9216, action_dim=87, hidden_dims=(9300, 9250)):
        super(Forward, self).__init__()
        hidden1, hidden2 = hidden_dims
        self.l1 = nn.Linear(latent_dim + action_dim, hidden1)
        self.l2 = nn.Linear(hidden1, hidden2)
        # Must consume l2's output and emit a latent-sized prediction. The previous
        # Linear(9300, 9250) crashed on l2's 9250-wide output and would have
        # produced 9250 features instead of the 9216-wide target.
        self.l3 = nn.Linear(hidden2, latent_dim)

    def forward(self, x_0, px, py, theta, l):
        """Predict the next latent (batch, latent_dim) from ``x_0`` and the action pieces."""
        x = torch.cat((l, theta, px, py, x_0), 1)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        return self.l3(x)

In [None]:
# Instantiate both models. Forward is built first, so random initialisation
# order is unchanged; .float() keeps the float32 dtype explicit.
fwd = Forward().float()
inv = Inverse().float()

In [None]:
# Joint training of the inverse and forward models.
# The Inverse heads already end in a softmax, so each head is scored with NLL on
# log-probabilities. nn.CrossEntropyLoss expects logits and would apply a second
# softmax on top of them.
NUM_EPOCHS = 2
LOG_EVERY = 2000           # mini-batches between loss printouts
FORWARD_LOSS_WEIGHT = 0.1  # weight of the forward-model L1 term

optimizer = optim.Adam(list(fwd.parameters()) + list(inv.parameters()))
nll_loss = nn.NLLLoss()
l1_loss = nn.L1Loss()


def head_loss(probs, one_hot):
    """Negative log-likelihood of softmax outputs ``probs`` against one-hot targets."""
    # Clamp avoids log(0) = -inf when a probability underflows.
    return nll_loss(torch.log(probs.clamp_min(1e-12)), one_hot.argmax(dim=1))


for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    # Loop variable is `batch`, not `data`, so the dataset bound to `data` survives.
    for i, batch in enumerate(dataLoader):
        (img_before, img_after), actions = batch
        # Dataset encodings are float64 (np.zeros); the models are float32.
        px_true, py_true, theta_true, l_true = (a.float() for a in actions)

        optimizer.zero_grad()

        x_0, x_1, px, py, theta, l = inv(img_before, img_after)
        # The forward model is conditioned on the ground-truth action.
        x_1Hat = fwd(x_0, px_true, py_true, theta_true, l_true)

        loss = (head_loss(px, px_true) + head_loss(py, py_true)
                + head_loss(theta, theta_true) + head_loss(l, l_true)
                + FORWARD_LOSS_WEIGHT * l1_loss(x_1, x_1Hat))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

print('Finished Training')