In [1]:
import pickle
import torch
from torch import nn
import matplotlib.pyplot as plt
from constants import *
import numpy as np
import cv2

In [6]:
# Load one recorded session. The `with` block closes the file handle once it
# has been read (the original `pickle.load(open(...))` left it open).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted,
# locally produced recordings.
with open("./val_range_data/data11.pkl", "rb") as recording_file:
    data11 = pickle.load(recording_file)

In [7]:
# Flatten the loaded recordings into one list per stream:
# element 0 of each recording extends `frames`, element 1 `keyboard_data`,
# element 2 `mouse_data` and element 3 `cursor_positions`.
frames = []
keyboard_data = []
mouse_data = []
cursor_positions = []
for recording in [data11]:
    frames += recording[0]
    keyboard_data += recording[1]
    mouse_data += recording[2]
    cursor_positions += recording[3]

In [8]:
# Bits of the raw-input mouse `usButtonFlags` field (Windows RAWMOUSE). The
# field is a bitmask -- a single event may report several transitions -- so
# bits are tested with `&` rather than compared with `==`.
_RI_MOUSE_LEFT_BUTTON_DOWN = 0x0001
_RI_MOUSE_LEFT_BUTTON_UP = 0x0002
_RI_MOUSE_BUTTON_4_DOWN = 0x0040
_RI_MOUSE_BUTTON_5_DOWN = 0x0100

# Key codes whose state is tracked per step (presumably Windows virtual-key
# codes -- TODO confirm against the recorder), mapped to feature names.
_TRACKED_KEYS = {
    0x57: "w",
    0x41: "a",
    0x53: "s",
    0x44: "d",
    0x20: "space",
    0xA0: "shift",
    0x11: "ctrl",
}


def _summarize_step(mouse_events, keyboard_events):
    """Reduce the raw input events stored for one step to a dict of features.

    Keys of the returned dict:
      * "mouse_dx", "mouse_dy": sums of `lLastX` / `lLastY` over the mouse
        events, rescaled to unit length (both stay 0 when there was no motion).
      * "leftclick": set by the first event carrying a left-button transition:
        1 for a press bit, -1 for a release bit; 0 if there is none.
      * "reward": set by the first event that presses button 5 (-> 1) or
        button 4 (-> -1); 0 if neither occurs.
      * "w", "a", "s", "d", "space", "shift", "ctrl": updated by every event
        for that key -- state 1 -> -1, state 0 -> 1, other states leave it
        unchanged; 0 if the key never appears.

    Args:
        mouse_events: iterable of mouse records exposing `lLastX`, `lLastY`
            and `union.structure.usButtonFlags`.
        keyboard_events: iterable of indexable records `(key_code, state, ...)`.
    """
    step = {"mouse_dx": 0, "mouse_dy": 0, "leftclick": 0, "reward": 0}
    step.update(dict.fromkeys(_TRACKED_KEYS.values(), 0))

    for event in mouse_events:
        step["mouse_dx"] += event.lLastX
        step["mouse_dy"] += event.lLastY
        flags = event.union.structure.usButtonFlags

        # First transition wins for both the click and the reward signal.
        if step["leftclick"] == 0:
            if flags & _RI_MOUSE_LEFT_BUTTON_DOWN:
                step["leftclick"] = 1
            elif flags & _RI_MOUSE_LEFT_BUTTON_UP:
                step["leftclick"] = -1
        if step["reward"] == 0:
            if flags & _RI_MOUSE_BUTTON_5_DOWN:
                step["reward"] = 1
            elif flags & _RI_MOUSE_BUTTON_4_DOWN:
                step["reward"] = -1

    for entry in keyboard_events:
        key = _TRACKED_KEYS.get(entry[0])
        if key is None:
            continue
        if entry[1] == 1:
            step[key] = -1
        elif entry[1] == 0:
            step[key] = 1

    # Keep only the direction of the mouse motion, not its speed.
    magnitude = (step["mouse_dx"] ** 2 + step["mouse_dy"] ** 2) ** 0.5
    if magnitude != 0:
        step["mouse_dx"] /= magnitude
        step["mouse_dy"] /= magnitude
    return step


data = []
rewards = []  # not populated while reward export is disabled (see below)
# Index 0 is skipped, so data[k] describes the events stored at index k + 1.
for i in range(1, len(keyboard_data)):
    step = _summarize_step(mouse_data[i], keyboard_data[i])
    # Only the normalised mouse direction is used as the action. Features
    # previously exported here and currently disabled: the key states
    # (w, a, s, d, space, shift, ctrl), leftclick, step["reward"] (into
    # `rewards`), and the next cursor position scaled as
    # cursor_positions[i + 1][0] / 1920 - 0.5 and
    # cursor_positions[i + 1][1] / 1080 - 0.5.
    data.append([step["mouse_dx"], step["mouse_dy"]])

In [9]:
# One row per step: the (dx, dy) action vectors collected above.
labels = np.asarray(data)

In [10]:
# Preprocess every frame in place. This is not a true grayscale conversion:
# each frame is downscaled by 2 in both dimensions, only channel 0 is kept
# (red, assuming RGB frames as produced by PIL -- TODO confirm), and the
# result is divided by 255 (so [0, 1] for uint8 input).
def _red_channel_half_res(frame):
    """Return channel 0 of `frame` at half resolution, divided by 255."""
    pixels = cv2.resize(np.array(frame), (0, 0), fx=0.5, fy=0.5)
    return pixels[:, :, 0] / 255


for idx, frame in enumerate(frames):
    processed = _red_channel_half_res(frame)
    print(idx, end='\r')
    frames[idx] = processed

5001

In [12]:
np.min(frames[0])  # sanity check: smallest pixel value of the first frame

0.0

In [23]:
class Discriminator(nn.Module):
    """Conditional discriminator: scores how plausible an action is for a frame.

    A convolutional encoder turns a single-channel frame into a 64-dim feature
    vector. That vector is concatenated with the action and mapped by a small
    MLP to one unnormalised logit per sample (pair it with
    `nn.BCEWithLogitsLoss`).

    Args:
        in_shape: `(height, width)` of the input frames. The default matches
            the 108x192 frames used in this notebook. Other sizes work because
            the flattened feature size is derived from this value.
        action_dim: length of the action vector (default 2, the mouse
            `(dx, dy)` direction).

    Raises:
        ValueError: if `in_shape` is too small to survive the four 2x pooling
            stages (either side below 16).
    """

    def __init__(self, in_shape=(108, 192), action_dim=2):
        super().__init__()

        # Each of the four MaxPool2d(2) layers floor-halves height and width;
        # for the default 108x192 input this yields 6x12 -> 128 * 6 * 12 = 9216.
        height, width = in_shape
        for _ in range(4):
            height, width = height // 2, width // 2
        if height < 1 or width < 1:
            raise ValueError(
                f"in_shape {tuple(in_shape)} is too small for 4 pooling stages "
                "(need at least 16x16)"
            )
        flat_features = 128 * height * width

        self.convs = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )

        # Input is the 64-dim frame embedding concatenated with the action.
        self.fc2 = nn.Sequential(
            nn.Linear(64 + action_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
        )

    def forward(self, state, action):
        """Return one discriminator logit per sample.

        Args:
            state: frames of shape `(batch, 1, height, width)`.
            action: actions of shape `(batch, action_dim)`.

        Returns:
            Tensor of shape `(batch, 1)` with raw (pre-sigmoid) logits.
        """
        x = self.convs(state)
        x = self.fc(x)
        x = torch.cat((x, action), dim=1)
        return self.fc2(x)

In [14]:
class Generator(nn.Module):
    """Policy network: predicts an action vector from a single frame.

    Uses the same convolutional encoder and 64-dim feature MLP layout as the
    discriminator, followed by a small head that emits `action_dim` unbounded
    values (no output activation).

    Args:
        in_shape: `(height, width)` of the input frames. The default matches
            the 108x192 frames used in this notebook. Other sizes work because
            the flattened feature size is derived from this value.
        action_dim: number of action components to predict (default 2, the
            mouse `(dx, dy)` direction).

    Raises:
        ValueError: if `in_shape` is too small to survive the four 2x pooling
            stages (either side below 16).
    """

    def __init__(self, in_shape=(108, 192), action_dim=2):
        super().__init__()

        # Each of the four MaxPool2d(2) layers floor-halves height and width;
        # for the default 108x192 input this yields 6x12 -> 128 * 6 * 12 = 9216.
        height, width = in_shape
        for _ in range(4):
            height, width = height // 2, width // 2
        if height < 1 or width < 1:
            raise ValueError(
                f"in_shape {tuple(in_shape)} is too small for 4 pooling stages "
                "(need at least 16x16)"
            )
        flat_features = 128 * height * width

        self.convs = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )

        self.fc2 = nn.Sequential(
            nn.Linear(64, 15),
            nn.ReLU(),
            nn.Linear(15, action_dim),
        )

    def forward(self, state):
        """Predict actions for a batch of frames.

        Args:
            state: frames of shape `(batch, 1, height, width)`.

        Returns:
            Tensor of shape `(batch, action_dim)` (unbounded values).
        """
        x = self.convs(state)
        x = self.fc(x)
        return self.fc2(x)

In [15]:
class GAN:
    """Bundle of one `Discriminator` and one `Generator`, both built with defaults.

    Attributes:
        discriminator: the `Discriminator` instance.
        generator: the `Generator` instance.
    """

    def __init__(self):
        # The discriminator is built first; keeping this order means that, under
        # a fixed seed, parameter initialisation draws random numbers in the
        # same sequence.
        self.discriminator = Discriminator()
        self.generator = Generator()


In [24]:
gan = GAN()
# Move the networks to the GPU *before* building their optimizers, as the
# torch.optim documentation recommends, so each optimizer is constructed over
# the device-resident parameters.
gan.discriminator = gan.discriminator.to('cuda')
gan.generator = gan.generator.to('cuda')
doptim = torch.optim.Adam(gan.discriminator.parameters())
goptim = torch.optim.Adam(gan.generator.parameters())
# Both networks output raw logits, so the sigmoid is folded into the loss.
dloss = nn.BCEWithLogitsLoss()
glossfn = nn.BCEWithLogitsLoss()

In [17]:
# Materialise both datasets as contiguous float32 arrays, matching the
# `.float()` tensors built from them during training.
frames, labels = (
    np.array(frames, dtype=np.float32),
    np.array(labels, dtype=np.float32),
)

In [63]:
np.shape(labels)  # (num_steps, action_dim)

(674, 10)

In [18]:
# Number of (frame, action) pairs per optimisation step.
BATCH_SIZE: int = 6

In [19]:

# Number of full batches per epoch, derived from the data instead of the
# hard-coded `5000 // 6`; a trailing partial batch is dropped.
min(len(frames), len(labels)) // BATCH_SIZE

833

In [25]:
# Put both networks and both loss modules on the GPU.
gan.discriminator, gan.generator = gan.discriminator.to('cuda'), gan.generator.to('cuda')
dloss, glossfn = dloss.to('cuda'), glossfn.to('cuda')

In [30]:
# Adversarial training: the discriminator learns to separate recorded actions
# from generated ones, and the generator learns to fool it.
num_epochs = 750
d_steps_per_g_step = 1
# Only full batches are used; the count is derived from the data rather than
# hard-coded.
num_batches = min(len(frames), len(labels)) // BATCH_SIZE

for epoch in range(num_epochs):
    for i in range(num_batches):
        batch = slice(i * BATCH_SIZE, (i + 1) * BATCH_SIZE)
        # as_tensor avoids an extra host copy of the float32 arrays. There is
        # no .squeeze() on the actions, so a batch of one keeps its batch dim.
        img = torch.as_tensor(frames[batch]).unsqueeze(1).float().to('cuda')
        real_actions = torch.as_tensor(labels[batch]).float().to('cuda')

        # One generator forward per batch. The discriminator step uses a
        # detached copy, so its backward pass does not traverse the generator.
        # The generator step reuses the attached tensor: doptim.step() does
        # not touch generator weights, so the values match a fresh forward.
        fake_actions = gan.generator(img)

        # train discriminator
        for _ in range(d_steps_per_g_step):
            doptim.zero_grad()
            real = gan.discriminator(img, real_actions)
            dreal_loss = dloss(real, torch.ones_like(real))
            dreal_loss.backward()

            fake = gan.discriminator(img, fake_actions.detach())
            dfake_loss = dloss(fake, torch.zeros_like(fake))
            dfake_loss.backward()

            doptim.step()

        # train generator: reward fakes the updated discriminator labels real
        goptim.zero_grad()
        fake = gan.discriminator(img, fake_actions)
        gloss_val = glossfn(fake, torch.ones_like(fake))
        gloss_val.backward()
        goptim.step()

        print(epoch, i, gloss_val.item(), dreal_loss.item(), dfake_loss.item(), end="\r")

0 296 2.3540124893188477 0.22729119658470154 0.101694568991661072

KeyboardInterrupt: 

In [34]:
# Evaluate the generator on a single frame and compare it with the recorded action.
eval_index = 800
# Inference only: no_grad avoids building an autograd graph for these forwards.
with torch.no_grad():
    img = torch.tensor(frames[eval_index]).unsqueeze(0).unsqueeze(0).float()
    print(img.shape)
    real_actions = torch.tensor(labels[eval_index]).squeeze().float()
    img = img.to('cuda')
    real_actions = real_actions.to('cuda')
    fake_actions = gan.generator(img)
    # Discriminator logit for the generated action; kept for inspection.
    fake = gan.discriminator(img, fake_actions)
print(fake_actions)
print(real_actions)

torch.Size([1, 1, 108, 192])
tensor([[ 0.2265, -0.0269]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor([0., 1.], device='cuda:0')
