# Geometric Scene Sim Turing Test

## Set Game Parameters

In [1]:
# ---------------------------------------------------------------------------
# Language parameters
# ---------------------------------------------------------------------------
SEQ_LEN = 3                  # steps the speaker GRU unrolls per message
VOCAB_SIZE = 9               # number of discrete symbols (VQ codebook entries)

# ---------------------------------------------------------------------------
# Dataset parameters (forwarded to shapedata.ShapeData)
# ---------------------------------------------------------------------------
IMG_DIM = 32                 # passed as ShapeData(im_size=...)
MIN_SHAPES = 1               # passed as ShapeData(min_shapes=...)
MAX_SHAPES = 3               # passed as ShapeData(max_shapes=...)
OUTLINE = (255, 255, 255)    # passed as ShapeData(outline=...); presumably an RGB colour
ALEC_MODE = True             # passed as ShapeData(alec_mode=...)

# ---------------------------------------------------------------------------
# Model training parameters
# ---------------------------------------------------------------------------
BATCHES_TRAIN = 1            # number of optimisation steps
BATCH_SIZE = 1024            # images per step; the model splits them into two paired halves

# ---------------------------------------------------------------------------
# Gameplay parameters
# ---------------------------------------------------------------------------
SKIP_INTRO = False           # True skips the introduction and instructions
NUM_GAMES = 24               # number of rounds to play
ANSWER_TIME = 3              # seconds the human has to answer each question
BREAK_TIME = 2               # seconds between adjacent questions

## Loading Dataset & Analysis Utils

In [2]:
!wget -O shapedata.py https://raw.githubusercontent.com/interactive-intelligence/emergent-lang/main/shapedata.py
import shapedata
import importlib
importlib.reload(shapedata)

## Training

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from PIL import Image

from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2

from IPython.display import display
from tqdm.notebook import tqdm

# "borrowed" from https://github.com/zalandoresearch/pytorch-vq-vae
# with slight adjustments to make it work for plain (non-2d) vectors

class VectorQuantizerEMA(nn.Module):
    """Vector-quantisation layer whose codebook is maintained by exponential moving average.

    Every ``embedding_dim``-long row of the input is replaced by its nearest
    (squared-Euclidean) codebook vector. In training mode each forward pass
    also updates the codebook as an EMA of the inputs assigned to each code,
    with Laplace smoothing of the per-code counts.

    Args:
        num_embeddings: number of codebook vectors (the discrete vocabulary size).
        embedding_dim: length of each codebook vector / input row.
        commitment_cost: weight applied to the commitment MSE loss.
        decay: EMA decay used for both the cluster sizes and the per-code sums.
        epsilon: Laplace-smoothing constant added to every cluster size.
    """
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()
        
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        
        # Codebook: one row per discrete symbol, initialised with normal_() (N(0, 1)).
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost
        
        # EMA state: smoothed per-code assignment counts (registered buffer) and
        # smoothed per-code sums of assigned inputs. Codebook = sums / counts.
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()
        
        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        """Quantise ``inputs`` and, in training mode, update the codebook.

        Args:
            inputs: tensor whose element count is a multiple of ``embedding_dim``;
                it is viewed as rows of length ``embedding_dim``. In this notebook
                it is the speaker output from ``Model.speak``, presumably
                (seq_len, batch, embedding_dim) -- confirm at call sites.

        Returns:
            (loss, quantized, perplexity, encodings):
            loss -- commitment_cost * MSE(inputs, detached codes);
            quantized -- codes reshaped to ``inputs.shape``, wired with the
                straight-through estimator so gradients flow to ``inputs``;
            perplexity -- exp(entropy) of the average code usage in this call;
            encodings -- one-hot (num_rows, num_embeddings) code assignments.
        """
        input_shape = inputs.shape
        
        # Flatten to (num_rows, embedding_dim)
        flat_input = inputs.view(-1, self._embedding_dim)
        
        # Squared Euclidean distance to every code: |x|^2 + |e|^2 - 2 x.e
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
            
        # Nearest code per row, expanded into a one-hot (num_rows, num_embeddings) matrix
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        
        # Quantize and unflatten (computed from the codebook *before* this call's EMA update)
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
        
        # Use EMA to update the embedding vectors (training mode only).
        # NOTE(review): the assignments below bind brand-new nn.Parameter objects
        # to _ema_w and _embedding.weight on every training call, so an optimizer
        # constructed earlier keeps references to the old tensors. That is only
        # harmless as long as nothing is meant to train these tensors by gradient.
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)
            
            # Laplace smoothing of the cluster size: keeps every count > 0 while
            # preserving the total n, so the division below never hits zero.
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n)
            
            # Per-code sum of the inputs assigned to it in this batch
            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
            
            # New codebook = smoothed sums / smoothed counts (per-code mean)
            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
        
        # Commitment loss: pulls the inputs toward their (detached) codes
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss
        
        # Straight Through Estimator: forward value is the code, gradient w.r.t. inputs is identity
        quantized = inputs + (quantized - inputs).detach()
        # Perplexity = exp(entropy) of average code usage; 1e-10 avoids log(0)
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        
        # Leftover from the 2-D (image) upstream version, which permuted BHWC -> BCHW;
        # flat vectors need no permute, so the plain contiguous tensor is returned.
        # return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
        return loss, quantized.contiguous(), perplexity, encodings

class VisionModule(nn.Module):
    """Small CNN that embeds a batch of RGB images into 64-d feature vectors.

    Architecture: Conv(3->16, k=5) -> SiLU -> MaxPool(2) -> Conv(16->32, k=3)
    -> SiLU -> MaxPool(2) -> Flatten -> Linear(->64) -> BatchNorm1d(64).

    Args:
        img_dim: side length of the square input images. ``None`` (default)
            uses the notebook's ``IMG_DIM``, so the Linear layer's input size
            follows the dataset resolution instead of being hard-coded for 32.

    Raises:
        ValueError: if ``img_dim`` is too small to survive the conv/pool stack.
    """

    def __init__(self, img_dim=None):
        super(VisionModule, self).__init__()

        if img_dim is None:
            img_dim = IMG_DIM

        # Unpadded convs shrink each side by (kernel - 1); MaxPool2d(2) halves it
        # with flooring. For img_dim=32: 32 -> 28 -> 14 -> 12 -> 6 (= 32*6*6 inputs).
        side = ((img_dim - 4) // 2 - 2) // 2
        if side < 1:
            raise ValueError(
                f'img_dim={img_dim} is too small for VisionModule (must be at least 12)')

        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, 5),
            nn.SiLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(16, 32, 3),
            nn.SiLU(),
            nn.MaxPool2d(2),

            nn.Flatten(),

            nn.Linear(32 * side * side, 64),

            nn.BatchNorm1d(64),
        )

    def forward(self, x):
        """Embed images ``x`` of shape (batch, 3, img_dim, img_dim) into (batch, 64)."""
        return self.cnn(x)


class Model(nn.Module):
    """Speaker/listener model for the "same scene?" referential game.

    ``cnn`` embeds each image into 64 dims. ``speak`` unrolls the encoder GRU
    into a message of ``SEQ_LEN`` steps, and ``vq`` snaps every message step onto
    one of ``VOCAB_SIZE`` codebook vectors. ``listen`` runs the decoder GRU over
    a message, starting from the *other* image's embedding, and outputs a
    sigmoid probability that the two scenes are the same.

    Args:
        no_vq: if True, skip vector quantisation. Listeners then read the
            continuous speaker outputs, and ``forward`` returns only the
            predictions.
    """

    def __init__(self, no_vq=False):
        super(Model, self).__init__()

        self.no_vq = no_vq

        self.cnn = VisionModule()
        self.encoderRNN = nn.GRU(64, 64, 1)
        self.vq = VectorQuantizerEMA(VOCAB_SIZE, 64, 0.25, 0.95)
        self.decoderRNN = nn.GRU(64, 64, 1)
        # Listener head: [decoder final hidden state, own image embedding] -> probability.
        self.fc = nn.Sequential(
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def speak(self, z, max_len=SEQ_LEN):
        """Unroll the speaker GRU for ``max_len`` steps.

        The hidden state starts as the image embedding ``z`` (batch, 64). The
        first input is all zeros, and each later input is the previous step's
        output. The GRU is not batch_first, so tensors are (step, batch, 64).

        Returns:
            (max_len, batch, 64) tensor of per-step outputs (the message).
        """
        token = torch.zeros_like(z).unsqueeze(0)
        h_n = z.unsqueeze(0)

        message = []
        for _ in range(max_len):
            token, h_n = self.encoderRNN(token, h_n)
            message.append(token)

        return torch.cat(message, 0)

    def listen(self, seq, z):
        """Score message ``seq`` (steps, batch, 64) against image embeddings ``z`` (batch, 64).

        Returns:
            (batch, 1) sigmoid probabilities that the speaker's scene equals ``z``'s scene.
        """
        _, h_n = self.decoderRNN(seq, z.unsqueeze(0))
        fc_in = torch.cat([h_n.squeeze(0), z], dim=1)
        return self.fc(fc_in)

    def forward(self, x):
        """Play one round of the game on a batch of images.

        ``x`` is split into halves: image ``i`` of the first half is paired with
        image ``i`` of the second half. Each half's message is judged by the
        listener that holds the other half's image.

        Returns:
            With ``no_vq``: (batch, 1) probabilities.
            Otherwise: (vq_loss, (batch, 1) probabilities, one-hot token
            encodings reshaped to (seq_len, batch, VOCAB_SIZE)).

        Raises:
            ValueError: if the batch size is odd (it cannot be split into pairs).
        """
        batch_size = x.shape[0]
        # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
        if batch_size % 2 != 0:
            raise ValueError(f'batch size must be even to form image pairs, got {batch_size}')

        z = self.cnn(x)
        seq = self.speak(z)

        left = slice(None, batch_size // 2)
        right = slice(batch_size // 2, None)

        if self.no_vq:
            out1 = self.listen(seq[:, left], z[right])
            out2 = self.listen(seq[:, right], z[left])
            return torch.cat([out1, out2])

        vq_loss, q_seq, _, enc = self.vq(seq)

        out1 = self.listen(q_seq[:, left], z[right])
        out2 = self.listen(q_seq[:, right], z[left])

        seq_len = q_seq.shape[0]
        enc = enc.reshape(seq_len, batch_size, -1)

        return vq_loss, torch.cat([out1, out2]), enc
    
batch_size = BATCH_SIZE
device = 'cpu'

model = Model().to(device)

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters())

# ShapeData is asked for batch_size // 2 items per call, presumably so that x1
# and x2 together give Model.forward a batch of batch_size images -- confirm.
data = shapedata.ShapeData(batch_size=batch_size//2, im_size=IMG_DIM, min_shapes=MIN_SHAPES, max_shapes=MAX_SHAPES,
                           alec_mode = ALEC_MODE, outline = OUTLINE)

losses = []
model.train()
bar = tqdm(range(BATCHES_TRAIN))
for batch in bar:

    # get data from function
    (x1, x1_shapes), (x2, x2_shapes), y = data.create_batch()
    X, y = shapedata.to_pytorch_inputs(x1, x2, y)
    # Keep inputs on the model's device so editing `device` above is sufficient.
    X, y = X.to(device), y.to(device)

    # Compute prediction error: BCE on same/different plus the VQ commitment loss
    vq_loss, pred, enc = model(X)
    pred = pred.flatten()
    loss = criterion(pred, y) + vq_loss

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    bar.set_postfix({'loss': f'{loss.item():>6f}'})
    losses.append(loss.item())

plt.figure(figsize=(10, 4), dpi=400)
plt.plot(losses)
plt.xlabel('training batch')
plt.ylabel('loss')
plt.show()
plt.close()

## Game
Press `Shift` + `Enter` to run the game cell. Depending on the viewer, Jupyter should auto-scroll to keep the newest output in view.

In [4]:
from threading import Thread
import os
import time


"""
INTRODUCTION TEXT DYNAMICS
"""

def clearConsole():
    """Push earlier output out of view by printing a run of blank lines.

    Emits 15 newline characters plus print's own trailing newline.
    """
    blank_lines = '\n' * 15
    print(blank_lines)

# Raw string: the banner contains a backslash ("\_/"), which in a normal string
# literal is an invalid escape sequence (SyntaxWarning on Python 3.12+).
welcome_message = r"""
 _ _ _     _                      _          _____                                 _____ __    _____ 
| | | |___| |___ ___ _____ ___   | |_ ___   |  |  |_ _ _____ ___ ___    _ _ ___   |   __|  |  |     |
| | | | -_| |  _| . |     | -_|  |  _| . |  |     | | |     | .'|   |  | | |_ -|  |   __|  |__| | | |
|_____|___|_|___|___|_|_|_|___|  |_| |___|  |__|__|___|_|_|_|__,|_|_|   \_/|___|  |_____|_____|_|_|_|
"""

horizontal_rule = '_____________________________________________________________________________________________________'

# Reveal the banner and divider with a short pause after each.
print(welcome_message)
time.sleep(1)
print(horizontal_rule)
time.sleep(1)

if not SKIP_INTRO:

    # f-string so the rules always reflect the gameplay parameters; literal
    # braces in the text are doubled.
    instructions = f"""
    INSTRUCTIONS

    You will be presented with two images of a scene. This scene will contain several objects. Each object
    is either {{red, green, blue}} in color and either a {{square, circle, triangle}}. The objects may be in
    different states of rotation and/or overlap.

    Two geometric scenes are considered to be the same if they feature exactly the same set of objects, even if
    the objects themselves are arranged in different ways.

    A human will play against an Emergent Language Model. The human is given {ANSWER_TIME} seconds to enter either:
    - 0 to indicate that they believe two scenes are different
    - 1 to indicate that they believe two scenes are the same

    After each guess, the answer will be displayed for {BREAK_TIME} seconds. Then, the next question will be displayed.
    In total, the game will cycle through {NUM_GAMES} continuous questions, running for roughly
    {NUM_GAMES} * ({ANSWER_TIME} + {BREAK_TIME}) = {NUM_GAMES * (ANSWER_TIME + BREAK_TIME)} seconds nonstop.

    The human wins if they score a higher accuracy than the Emergent Language Model.
    """

    # Print the rules line by line so they can be read as they appear.
    for line in instructions.split('\n'):
        print(line)
        time.sleep(0.5)

    print('\nWhat is the name of the human challenging the Emergent Language Model?')
    name = input(':: ')

    print(f'\nAre you ready to play, {name}? (Enter anything to continue)')
    input(':: ')

def timer(seconds):
    """Count down on a single self-overwriting line, then clear the console.

    While the counter is positive, prints ``SECONDS REMAINING: n...`` ending in
    a carriage return and sleeps one second before decrementing it.
    """
    remaining = seconds
    while remaining > 0:
        message = f'SECONDS REMAINING: {remaining}...'
        print(message, end = "\r")
        remaining = remaining - 1
        time.sleep(1)
    clearConsole()

from threading import Event

# prepare to play
print('\nGet ready to play...')
timer(3)
clearConsole()

# track human vs model score
score = 0
model_score = 0

# Create a fresh evaluation batch and obtain all model predictions ahead of time.
(x1, x1_shapes), (x2, x2_shapes), y = data.create_batch()
X, y = shapedata.to_pytorch_inputs(x1, x2, y)
X, y = X.to(device), y.to(device)
# Inference only, so no autograd graph is built.
# NOTE(review): the model stays in training mode, as before. BatchNorm
# therefore normalises with this batch's own statistics, because its running
# averages have only seen BATCHES_TRAIN updates. A side effect is that the VQ
# codebook's EMA update also runs on this batch. Switch to model.eval() once
# training runs long enough for the running statistics to be meaningful.
with torch.no_grad():
    loss, preds, enc = model(X)

# Never ask for more rounds than there are image pairs in the batch.
num_games = min(NUM_GAMES, len(x1))


def _countdown(seconds, answered):
    """Show a per-second countdown until `seconds` run out or `answered` is set."""
    remaining = seconds
    while remaining > 0 and not answered.is_set():
        print(f'SECONDS REMAINING: {remaining}...', end='\r')
        remaining -= 1
        answered.wait(1)  # returns early as soon as the player answers
    if not answered.is_set():
        print("You were too slow. Press enter to continue.")


# execute game play
for game in range(num_games):

    print(f'GAME {game+1}/{num_games}')

    # obtain model prediction
    pred = preds[game]

    # show question
    plt.subplot(1, 2, 1)
    plt.imshow(x1[game])
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.imshow(x2[game])
    plt.axis('off')
    plt.show()

    # Obtain the user's answer. input() cannot be interrupted, so the time
    # limit is enforced by measuring how long the answer took. The countdown
    # thread is only a visual aid: it is stopped and joined as soon as input()
    # returns, so it can never spill into a later round.
    answered = Event()
    countdown = Thread(target=_countdown, args=(ANSWER_TIME, answered), daemon=True)
    start = time.monotonic()
    countdown.start()
    answer = input("Answer (0/1): ").strip()
    elapsed = time.monotonic() - start
    answered.set()
    countdown.join()
    in_time = elapsed <= ANSWER_TIME

    # border
    print('-'*50 + '\n')

    # calculate and print round results
    truth = int(y[game])
    human_guess = int(answer[0]) if answer[:1] in ('0', '1') else None
    print('END OF ROUND STATS:')
    if not in_time:
        print(f'Too slow ({elapsed:.1f}s > {ANSWER_TIME}s) - your answer was not counted.')
    elif human_guess == truth:
        print('Your answer was correct!')
        score += 1
    else:
        print('Your answer was incorrect.')

    prob = pred.mean().item()
    print(f'Model predicted: {np.round(prob, 4)}')
    if np.round(prob) == truth:
        print('The model was correct!')
        model_score += 1
    else:
        print('The model was incorrect!')
    print(f'The correct answer was {truth}.')
    print(f'YOUR CURRENT ACCURACY:  {score / (game + 1)}')
    print(f'MODEL CURRENT ACCURACY: {model_score / (game + 1)}')
    print('\n' + '-'*50 + '\n')

    # next question
    print('Next question in...')
    timer(BREAK_TIME)
    clearConsole()

# print final information (max() avoids dividing by zero when no rounds were played)
played = max(num_games, 1)
print('\n' + '-'*50 + '\n')
print('GAME FINISHED! FINAL STATS:')
print(f'YOUR ACCURACY:  {score / played} ({score} / {num_games})')
print(f'MODEL ACCURACY: {model_score / played} ({model_score} / {num_games})')