In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as tf
import PIL.Image

# Reference glyph images. tf.to_tensor turns each PIL image into a (C, H, W) tensor;
# only channel index 1 is kept from each one.
kappa = PIL.Image.open('data/kappa.png')
j = PIL.Image.open('data/J.png')
dd = PIL.Image.open('data/dd.png')
maps = PIL.Image.open('data/maps.png')

kappa_tensor, j_tensor, dd_tensor, maps_tensor = (
    tf.to_tensor(picture)[1] for picture in (kappa, j, dd, maps)
)

# Glyph geometry, presumably (height, width): Kappa views inputs as (-1, 34, 28) and
# train() rolls glyphs along the 28-long axis — TODO confirm against the PNG sizes.
image_shape = torch.tensor([34, 28])
image_size = image_shape.prod()
offset_size = int(image_shape[1])

classes_count = 2

class Kappa(nn.Module):
    """Toy glyph model: classify an image, recall a stored per-class image, then re-roll it.

    forward(image) returns a tuple:
      class_  - (B, n_classes) softmax over classes, from fc1 on the flattened image;
      offset  - (B, width) softmax over column shifts, from fc2 on image + class_;
      recalled image - (B, H, W): the class-weighted blend of the `memory` rows, with every
                       column-roll of it mixed by the `offset` weights.

    `shape` / `n_classes` default to the notebook-level `image_shape` / `classes_count`,
    so `Kappa()` keeps its original behaviour.
    """

    def __init__(self, shape=None, n_classes=None):
        super(Kappa, self).__init__()

        # Fall back to the notebook-level constants so `Kappa()` behaves exactly as before.
        shape = image_shape if shape is None else shape
        n_classes = classes_count if n_classes is None else n_classes

        self.image_shape = tuple(int(dim) for dim in shape)
        height, width = self.image_shape
        self.image_size = height * width
        self.offset_size = width  # one candidate shift per column
        self.classes_count = int(n_classes)

        self.fc1 = torch.nn.Linear(self.image_size, self.classes_count)
        self.memory = torch.nn.Parameter(torch.zeros(self.classes_count, self.image_size))

        self.fc2 = torch.nn.Linear(self.image_size + self.classes_count, self.offset_size)

        # Row i holds flat pixel indices that, used as a gather index, roll a flattened
        # (height, width) image by i columns. A buffer follows `.to(device)`; persistent=False
        # keeps it out of the state_dict so previously saved checkpoints still load.
        grid = torch.arange(self.image_size).reshape(height, width)
        self.register_buffer(
            'offsets',
            torch.stack([grid.roll(i, 1).flatten() for i in range(self.offset_size)]),
            persistent=False)

    def _classify(self, flat):
        """Class probabilities for flattened images: (B, image_size) -> (B, classes_count)."""
        return F.softmax(self.fc1(flat), dim=1)

    def _choose_offset(self, flat, class_):
        """Shift probabilities from flattened images plus their class probabilities: -> (B, offset_size)."""
        return F.softmax(self.fc2(torch.cat([flat, class_], 1)), dim=1)

    def forward(self, image):
        flat = image.view(-1, self.image_size)

        class_ = self._classify(flat)
        # Blend the per-class memory images by class probability: (B, image_size).
        recalled_image = class_ @ self.memory

        offset = self._choose_offset(flat, class_)

        # candidates[b, i] is recalled_image[b] rolled by i columns: (B, offset_size, image_size).
        candidates = recalled_image[:, self.offsets]
        # Weight the rolled candidates by the offset distribution in one batched matmul.
        recalled_offset_image = torch.bmm(offset.unsqueeze(1), candidates).squeeze(1)

        return class_, offset, recalled_offset_image.view(-1, *self.image_shape)

    def determine_class(self, image):
        """Return only the class probabilities (same values as forward()'s first output)."""
        return self._classify(image.view(-1, self.image_size))

    def determine_offset(self, image):
        """Return only the shift probabilities (same values as forward()'s second output)."""
        flat = image.view(-1, self.image_size)
        return self._choose_offset(flat, self._classify(flat))

def train(model, device, optimizer, epochs, batch_size=16, images=None, classification_weight=0.0):
    """Fit `model` to reproduce randomly column-rolled copies of the reference glyphs.

    Each epoch draws `batch_size` glyph indices (with replacement) and rolls every sampled glyph
    by its own random number of columns. It feeds the rolled batch to the model and takes one
    optimizer step on the MSE between the model's recalled image and that same rolled batch.

    Args:
        model: module returning (class_probs, offset_probs, recalled_image), e.g. Kappa.
        device: device the glyphs and the sampled indices are moved to.
        optimizer: optimizer over the model's parameters.
        epochs: number of optimizer steps.
        batch_size: glyphs per step (default 16, as before).
        images: (n, H, W) glyph stack. Defaults to torch.stack([j_tensor, kappa_tensor]), so
            index 0 is J and index 1 is kappa.
        classification_weight: weight of an extra NLL term between class_probs and the sampled
            glyph index. 0 (default) trains on the reconstruction loss only, as before.
    """
    model.train()

    if images is None:
        images = torch.stack([j_tensor, kappa_tensor])
    images = images.to(device)
    width = images.shape[-1]

    for _ in range(epochs):
        # Sampled on `device` so the optional classification loss sees matching devices.
        indices = torch.randint(len(images), size=(batch_size,), device=device)

        # Roll every sampled glyph by its own random shift. Advanced indexing copies, so
        # `images` is untouched; data and target are separate copies with the same pixels.
        target = images[indices]
        for i in range(len(target)):
            shift = torch.randint(width, size=(1,)).item()
            target[i] = target[i].roll(shift, -1)
        data = target.clone()

        optimizer.zero_grad()

        class_, _, recalled_image = model(data)
        loss = F.mse_loss(recalled_image, target)
        if classification_weight:
            # class_ already holds probabilities, so NLL of their log is the cross-entropy.
            loss = loss + classification_weight * F.nll_loss(class_.log(), indices)

        loss.backward()
        optimizer.step()

In [None]:
# Seed before building the model so both the weight init and the sampled training
# batches are reproducible across Restart & Run All.
torch.manual_seed(0)

device = 'cpu'
model = Kappa().to(device)
optimizer = torch.optim.Adam(model.parameters())
train(model, device=device, optimizer=optimizer, epochs=2*2000)

In [None]:
# Class probabilities for J (a training glyph) and two glyphs not used in training,
# computed without autograd and rounded to 2 decimals for display.
with torch.no_grad():
    results = model.determine_class(torch.stack([j_tensor, dd_tensor, maps_tensor]))

(100*results).round() / 100

## IMPORTS

In [None]:
!pip install gym-retro

In [None]:
!python3 -m retro.import roms

In [None]:
!pip install pyspark

In [None]:
import retro
import torch
import random
import numpy as np
import IPython.display
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as tf
import PIL.Image

## ACTIONS

In [None]:
# Each action is a 9-slot 0/1 button vector; this maps an action name to the single slot it
# presses (slots 1-3 are never pressed here, and 'NONE' presses nothing).
_button_slot = {'B': 0, 'UP': 4, 'DOWN': 5, 'LEFT': 6, 'RIGHT': 7, 'A': 8}

# Order matters: generate_play picks random actions by index into this list.
actions = ['UP', 'DOWN', 'LEFT', 'RIGHT', 'NONE', 'B', 'A']

actions_codes = {
    name: [int(_button_slot.get(name) == slot) for slot in range(9)]
    for name in actions
}

# Reverse lookup: button-vector tuple -> action name.
codes_actions = {tuple(vector): name for name, vector in actions_codes.items()}

## ENVIRONMENT2

In [None]:
class Environment2:
    """gym-retro 'SuperMarioBros-Nes' wrapper with a fixed 16x16 "tile encoder".

    The encoder is one (1, 3, 16, 16) conv filter applied with stride 16, so every
    non-overlapping 16x16 RGB tile of a frame collapses to one scalar. The tile-row axis of
    the result is then cropped to rows 4..-2.
    """

    def __init__(self, record_frames=False):
        """Create the emulator and reset it.

        Args:
            record_frames: keep every observed frame in `frames_all`, which
                blocks_identify_all() needs. Off by default because a full frame history is large.
        """
        self.environment = retro.make(game='SuperMarioBros-Nes')
        # self.environment = retro.make(game='NinjaGaiden-Nes')
        # self.environment = retro.make(game='Airstriker-Genesis')

        self.blocks_seen = []
        self.blocks_seen_urls = []

        self.frames_all = [] if record_frames else None
        self.actions_all = []

        self.encodings = set()
        self.encodings_frame = set()

        self.filter_ = self._make_filter()

        self.frame = self.environment.reset()
        if self.frames_all is not None:
            self.frames_all.append(self.frame)

    @staticmethod
    def _make_filter():
        """Build the (1, 3, 16, 16) float32 encoding filter.

        Weights are a cubic ramp over [0.5, 1.5] laid out in (H, W, C) order and moved to
        conv2d's (out, in, kH, kW) layout. Dividing by 16*16*3*255 maps an all-255 tile to
        the mean weight.
        """
        ramp = torch.linspace(0.5, 1.5, 16 * 16 * 3) ** 3
        weight = ramp.reshape(16, 16, 3).unsqueeze(0).permute(0, 3, 1, 2).contiguous()
        return weight / 16 / 16 / 3 / 255

    def step(self, action, commitment_interval):
        """Hold `action` for `commitment_interval` emulator steps (frame skipping).

        Returns (frame, reward, is_done, information) from the last emulator step taken, with
        reward summed over all steps taken. Stops early once the environment reports done.

        Raises:
            ValueError: if commitment_interval < 1 (there would be nothing to return).
        """
        if commitment_interval < 1:
            raise ValueError(f"commitment_interval must be >= 1, got {commitment_interval}")

        total_reward = 0
        for _ in range(commitment_interval):
            self.frame, reward, is_done, information = self.environment.step(action)
            total_reward += reward
            if is_done:
                break

        if self.frames_all is not None:
            self.frames_all.append(self.frame)
        self.actions_all.append(action)

        return self.frame, total_reward, is_done, information

    def close(self):
        """Close the render window and the emulator. Safe to call repeatedly (also used as __del__)."""
        # getattr: __del__ may run on an instance whose __init__ failed before `environment` was set.
        environment = getattr(self, 'environment', None)
        if environment is None:
            return
        self.environment = None
        try:
            environment.render(close=True)
        finally:
            environment.close()

    __del__ = close

    def _encode(self, images):
        """Apply the tile filter to float (B, 3, H, W) images and crop tile rows to 4..-2."""
        output = F.conv2d(input=images, weight=self.filter_, stride=16)
        # NOTE(review): drops the first 4 tile rows and the last one — presumably HUD/ground; confirm.
        return output[:, :, 4:-1]

    def blocks_identify_all(self):
        """Encode every recorded frame; return (encodings[:-1], actions_all).

        Raises:
            RuntimeError: if the environment was created without record_frames=True.
        """
        if self.frames_all is None:
            raise RuntimeError("blocks_identify_all() needs Environment2(record_frames=True)")

        t = torch.tensor(np.stack(self.frames_all)).float()
        output = self._encode(t.permute(0, 3, 1, 2))

        # frames_all has one more entry than actions_all (the reset frame); dropping the last
        # encoding pairs each remaining encoding with the action taken from that frame.
        return output[:-1], self.actions_all

    def frame_encode(self, frame):
        """Encode one (H, W, 3) frame into a (1, 1, rows, cols) tile-encoding tensor."""
        t = torch.tensor(frame).unsqueeze(0).float()
        return self._encode(t.permute(0, 3, 1, 2))

## GENERATE

In [None]:
# %%time

import numpy as np
import random

import warnings
warnings.filterwarnings('ignore')

def generate_play(step_count, agent, return_frames=False, random_seed=None):
    """Play one episode of the Environment2 game; return its last frame and optionally all frames.

    Each step picks an action and holds it for 6 emulator steps. With an agent, the action is
    `agent.predict(encoding)[0]` on the current frame's tile encoding. Without one, it is drawn
    uniformly from `actions` with a private RNG seeded by `random_seed`. The play stops after
    `step_count` steps (unbounded if `step_count` is falsy) or as soon as
    `information['lives']` reads 1.

    Returns:
        {'FrameLast': last raw frame}. When return_frames is True it also holds
        'Frames': a uint8 tensor (T + 1, *frame.shape) with the reset frame and the frame after
        each of the T steps actually played. There is no zero padding after an early stop.
    """
    # Private RNG: same stream as random.seed(random_seed) but leaves the global module alone.
    rng = random.Random(random_seed)
    max_steps = step_count or 999999999

    environment = Environment2()
    try:
        frame = environment.frame
        # Only keep frames when asked for; a whole play of raw frames is memory-heavy.
        frames = [torch.tensor(frame, dtype=torch.uint8)] if return_frames else None

        for _ in range(max_steps):
            if agent is not None:
                encoding = environment.frame_encode(frame).flatten().unsqueeze(0)
                action = agent.predict(encoding)[0]
            else:
                action = actions[rng.randint(0, len(actions) - 1)]

            frame, _reward, _is_done, information = environment.step(actions_codes[action], 6)

            if frames is not None:
                frames.append(torch.tensor(frame, dtype=torch.uint8))

            # NOTE(review): stops when the reported `lives` value reaches 1 — confirm its on-screen meaning.
            if information['lives'] == 1:
                break
    finally:
        # Release the emulator even if a step or the agent raises, so a later
        # Environment2() in the same process starts clean.
        environment.close()

    result = {'FrameLast': frame}
    if return_frames:
        result['Frames'] = torch.stack(frames)
    return result

In [None]:
frames_acc = list()  # one frames tensor per collected play

In [None]:
%%time

for index in range(1):
    result = generate_play(step_count=1600, agent=None, return_frames=True)
    # print(result['EncodingsUniqueCount'], result['FramesCount'])
    plt.imshow(result['Frames'][-3]);
    frames_acc.append(result['Frames'])
    print(f"{index}.", end='')

In [None]:
import pyspark

# Single-machine Spark session ("local[*]") with generous memory limits; `sc` is the
# SparkContext used by the parallel play-generation cell.
spark_builder = pyspark.sql.SparkSession.builder.master("local[*]")
for setting, value in (
    ("spark.executor.memory", "200g"),
    ("spark.driver.memory", "200g"),
    ("spark.memory.offHeap.enabled", True),
    ("spark.memory.offHeap.size", "200g"),
    ("spark.driver.maxResultSize", "200g"),
):
    spark_builder = spark_builder.config(setting, value)
spark = spark_builder.getOrCreate()

sc = spark.sparkContext

In [None]:
def partition(frames, width, skip_top_rows=4, skip_bottom_rows=1):
    """Cut frames into non-overlapping width x width tiles and return the distinct tiles.

    Args:
        frames: (B, H, W, 3) tensor of uint8 pixel values (as produced by generate_play).
        width: tile side length in pixels.
        skip_top_rows / skip_bottom_rows: whole tile rows dropped from the top / bottom of
            every frame (defaults 4 and 1, matching the original hard-coded crop).

    Returns:
        uint8 tensor (U, width, width, 3) of the distinct tiles, in torch.unique order.
    """
    _, height, frame_width, _ = frames.shape
    tiles_per_row = frame_width // width
    tile_rows = height // width

    # unfold gives (B, 3*width*width, tiles); each tile's values are channel-major (c, y, x)
    # and tiles are ordered row by row.
    unfolded = F.unfold(input=frames.permute(0, 3, 1, 2).float(),
                        kernel_size=(width, width),
                        stride=width)
    patches = unfolded.permute(0, 2, 1)

    first = skip_top_rows * tiles_per_row
    last = (tile_rows - skip_bottom_rows) * tiles_per_row
    patches = patches[:, first:last]

    tiles = patches.reshape(-1, 3, width, width).permute(0, 2, 3, 1).to(torch.uint8)
    # Deduplicate along the tile axis; unlike squeeze(), this keeps that axis even when
    # only one distinct tile exists.
    return torch.unique(tiles, dim=0)

In [None]:
frames = generate_play(step_count=1600, agent=None, return_frames=True)['Frames']  # one random play

In [None]:
partition(frames, width=16).shape

In [None]:
frames2 = list()  # unique-tile tensors, one per generated play

In [None]:
%%time

for _ in range(1):
    frames = sc.parallelize([random.randint(0, 9999999999999999) for _ in range(60)]) .map (
        lambda x: partition(generate_play(1600, agent=None, return_frames=True, random_seed=x)['Frames'], 16)
    )

    frames2.extend(frames.collect())

In [None]:
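# Apparent log of "<plays collected>-<unique 16x16 tiles so far>" (e.g. 57599232 bytes /
# (16*16*3) = 74999 tiles, matching the last entry) — TODO confirm.
# 1-1940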
# 10-7261
# 40-17073
# 80-30214
# 128-32026
# 160-38604
# 180-38207
# 240-42112
# 300-46418
# 360-50823
# 420-53841
# 480-55910
# 540-59669
# 600-62157
# 660-64616
# 720-66916
# 780-69911
# 840-71086
# 900-72458
# 960-73631
# 1020-74999

In [None]:
torch.cat(frames2, dim=0).shape

In [None]:
import pickle

# Close (and flush) the file deterministically instead of relying on garbage collection.
# NOTE(review): the load cell reads 'data/frames2_16_16_3.tensor.pickle', not this file
# name — confirm which one is intended.
with open('data/frames2_1720295_16_16_3.tensor.pickle', 'wb') as frames_file:
    pickle.dump(torch.cat(frames2), frames_file)

In [None]:
import pickle

# NOTE(review): the dump cell writes 'data/frames2_1720295_16_16_3.tensor.pickle'; this reads
# a different name — confirm which file is intended.
# pickle.load can execute arbitrary code: only load files this notebook produced.
with open('data/frames2_16_16_3.tensor.pickle', 'rb') as frames_file:
    loaded = pickle.load(frames_file)

# The dump cell stores one concatenated tensor, but later cells treat `frames2` as a list of
# tensors (`torch.cat(frames2)`), so wrap a bare tensor in a list.
frames2 = [loaded] if isinstance(loaded, torch.Tensor) else list(loaded)

In [None]:
# NOTE(review): `blocks` is not defined anywhere in this notebook, so this cell raises
# NameError on a fresh kernel; it presumably meant another tile list (e.g. `frames2`) — confirm.
torch.cat(blocks).shape

In [None]:
# Only `Image` is used from forward; import it explicitly instead of `*`.
from forward import Image

asymmetric = torch.linspace(0.5, 1.5, 16*16*3)**3
asymmetric = asymmetric / 16 / 16 / 3 / 255

uniques = torch.unique(torch.cat(frames2), dim=0)
encodings = uniques.reshape(-1, 16*16*3).float() @ asymmetric

z1 = [Image(image, display_scale=2) for image in uniques]

# Stable sort of the Image objects by their tile's scalar encoding. Float keys give the same
# order as comparing the 0-dim tensors, without per-comparison tensor overhead.
order = sorted(range(len(z1)), key=encodings.tolist().__getitem__)
z1 = [z1[i] for i in order]

import pickle
with open('data/active.pickle', 'wb') as active_file:
    pickle.dump([len(z1), z1], active_file)

In [None]:
(len(frames2), len(z1))

In [None]:
# Distinct tiles kept as one (U, 16, 16, 3) tensor, ordered by their scalar encoding
# (stable, as with sorted()).
uniques = torch.unique(torch.cat(frames2), dim=0)
encodings = uniques.reshape(-1, 16*16*3).float() @ asymmetric

order = sorted(range(len(encodings)), key=encodings.tolist().__getitem__)
z1 = uniques[torch.tensor(order, dtype=torch.long)]


In [None]:
import pickle

# ims has shape (1, U, 256, 3): every sorted tile flattened to 256 pixels, stacked together.
# Its leading dim is 1, so this pickles [1, [Image(<(U, 256, 3) tensor>)]].
ims = z1.reshape(1, -1, 16*16, 3)

with open('data/active.pickle', 'wb') as active_file:
    pickle.dump([len(ims), [Image(im) for im in ims]], active_file)

In [None]:
57599232 / (16 * 16 * 3)  # bytes -> number of 16x16x3 uint8 tiles

In [None]:
%%time

# frames = result['Frames'][-3]
# frames = torch.tensor(frames).unsqueeze(0)

# frames = torch.tensor(result['Frames'])
# frames = result['Frames']
print(frames.shape)
frames = frames.permute(0, 3, 1, 2)
print(frames.shape)
frames = frames.to(torch.float32)

unfolded = F.unfold(input=frames,
                    kernel_size=(16, 16),
                    stride=16)

print("unf", unfolded.shape)

imagez = unfolded.permute(0, 2, 1)
print("perm", imagez.shape)
row = 4
imagez = imagez[:, row*15:-1*15]
imagez = imagez.reshape(-1, 210 - row*15 - 1*15, 3, 16, 16).permute(0, 1, 3, 4, 2)
imagez = imagez.to(torch.uint8)

imagez = imagez.reshape(-1, 16, 16, 3).unsqueeze(0)

print(imagez.shape)

# z1 = [[Image(im, display_scale=2) for im in frame] for frame in imagez]

import pickle
# pickle.dump(z1, open('../interface/data/active.pickle', 'wb'))

In [None]:
# Per-element weights (length 16*16*3) for collapsing one flattened 16x16 RGB tile to a
# scalar: a cubic ramp on [0.5, 1.5], scaled so an all-255 tile maps to the mean weight.
asymmetric = (torch.linspace(0.5, 1.5, 16 * 16 * 3) ** 3) / 16 / 16 / 3 / 255

In [None]:
%%time

from forward import *

unique = torch.unique(imagez, dim=1)
unique = unique.squeeze()
# unique.sort()
print(unique.shape)
encodings = unique.reshape(-1, 16*16*3).float() @ asymmetric

z1 = [Image(image, display_scale=2) for image in unique]

sorted_ = sorted(zip(encodings, z1), key=lambda x: x[0])
z1 = [b for a, b in sorted_]

import pickle
pickle.dump([len(z1), z1], open('data/active.pickle', 'wb'))