In [1]:
import torch
from torch.utils.data import Dataset, IterableDataset, DataLoader
import torch.nn as nn
import random
import numpy as np
import math
from tqdm.auto import tqdm
import torchmetrics

In [2]:
experiment_config = {
    'obj_size': 4,
    'obs_size': 16,
}
model_config = {
    'd_model': 128,
    'n_head': 8,
    'num_layers': 6,
    'num_obj': 4,
}
training_config = {
    'batch_size':512,
    'lr': 5e-5,
    'max_step': 3000
}

In [3]:
device = torch.device('cuda')

In [4]:
class SymbolicWorldIterableDataset(IterableDataset):

    def __init__(self, num_objects, observation_size):
        self.num_objects = num_objects
        self.observation_size = observation_size
        pass

    def generate_observation(self):
        observation = torch.randint(0, self.num_objects, (self.observation_size,))

        parity_1 = ((observation == 1).sum())%2
        return {'observation':observation, 'parity_1':parity_1}

    def __iter__(self):
        while(True):
            yield self.generate_observation()

iterable_dataset = SymbolicWorldIterableDataset(experiment_config['obj_size'], experiment_config['obs_size'])
loader = DataLoader(iterable_dataset, batch_size=training_config['batch_size'])

test_loaders_list = []
for observation_size in [4,8,16,32,64]:
    test_loader_name = f'observation_size_{observation_size}'
    test_loader_data = DataLoader(SymbolicWorldIterableDataset(experiment_config['obj_size'], observation_size),\
        batch_size=training_config['batch_size'])
    test_loaders_list.append({
        'name': test_loader_name,
        'data_loader': test_loader_data
    })

if True:
    for a in loader:
        print(a['observation'][0])
        print(a['parity_1'][0])
        break


tensor([3, 2, 1, 0, 2, 2, 0, 3, 3, 2, 1, 0, 3, 0, 3, 3])
tensor(0)


In [5]:
class MySimpleModel(nn.Module):
    
    def __init__(self, model_config):
        super(MySimpleModel, self).__init__()
        d_model = model_config['d_model']
        n_heads = model_config['n_head']
        self.n_layers = model_config['num_layers']
        n_objects = model_config['num_obj']
        n_clss = 5
        self.d_model = d_model
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True)
        self.token_embedding = torch.nn.Embedding(n_objects, d_model)
        self.clss_embedding = torch.nn.Embedding(n_clss, d_model)
        self.fc1 = torch.nn.Linear(d_model, 2)
        
    def forward(self, x, clss):
        '''
        x : long tensor: [batch, seq_len]
        clss: long tensor: [batch, 1]
        return: tensor: [batch, 2]
        '''
        temp = x
        temp = self.token_embedding(temp)
        clss_embedding = self.clss_embedding(clss)
        temp = torch.cat((clss_embedding, temp), dim=1)
        for _ in range(self.n_layers):
            temp = self.encoder_layer(temp)
        temp = temp[:, 0, :]
        temp = self.fc1(temp)
        return temp

In [6]:
model = MySimpleModel(model_config).to(device)
loss_module = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=training_config['lr'])
f1_module = torchmetrics.F1Score(num_classes=2).to(device)

In [7]:
def eval(test_loader):
    with torch.no_grad():
            test_f1s = []
            for j, batch in enumerate(test_loader):
                x = batch['observation'].to(device)
                clss = torch.ones((x.shape[0], 1)).long().to(device)
                y = batch['parity_1'].to(device)
                output = model(x, clss)
                preds = output.argmax(dim=1)
                test_f1s.append(f1_module(preds, y).item())
                if j > 10:
                    break
            f1 = sum(test_f1s)/len(test_f1s)
            return round(f1, 4)

In [8]:
loss_history = []
max_steps = training_config['max_step']
log_every_steps = 200
for i, batch in tqdm(enumerate(loader), total=max_steps):
    x = batch['observation'].to(device)
    clss = torch.ones((x.shape[0], 1)).long().to(device)
    y = batch['parity_1'].to(device)
    optimizer.zero_grad()
    output = model(x, clss)
    loss = loss_module(output, y)
    loss_history.append(loss.item())
    loss.backward()
    optimizer.step()
    if i > max_steps:
        break
    if i % log_every_steps == 0 :
        print("step {}".format(i))
        print("loss: {:.2e}".format(sum(loss_history)/len(loss_history)))
        loss_history = []
        train_f1 = f1_module(output.argmax(dim=1), y).item()
        print('train f1: ', round(train_f1,4))
        for test_loader in test_loaders_list:
            test_acc = eval(test_loader['data_loader'])
            print(test_loader['name'], " f1_score: " , test_acc)
        print('------------------------')

  0%|          | 0/3000 [00:00<?, ?it/s]

step 0
loss: 7.32e-01
train f1:  0.5059
observation_size_4  f1_score:  0.4749
observation_size_8  f1_score:  0.5018
observation_size_16  f1_score:  0.5096
observation_size_32  f1_score:  0.491
observation_size_64  f1_score:  0.4917
------------------------
step 200
loss: 6.98e-01
train f1:  0.4824
observation_size_4  f1_score:  0.5252
observation_size_8  f1_score:  0.5086
observation_size_16  f1_score:  0.4922
observation_size_32  f1_score:  0.5085
observation_size_64  f1_score:  0.4992
------------------------
step 400
loss: 6.96e-01
train f1:  0.5098
observation_size_4  f1_score:  0.4852
observation_size_8  f1_score:  0.4997
observation_size_16  f1_score:  0.505
observation_size_32  f1_score:  0.5062
observation_size_64  f1_score:  0.4959
------------------------
step 600
loss: 6.95e-01
train f1:  0.4824
observation_size_4  f1_score:  0.4849
observation_size_8  f1_score:  0.4932
observation_size_16  f1_score:  0.5023
observation_size_32  f1_score:  0.4897
observation_size_64  f1_scor

KeyboardInterrupt: 

In [None]:
a = model_config.copy()
a['b']=  0

In [None]:
a

{'d_model': 32, 'n_head': 4, 'num_layers': 12, 'num_obj': 8, 'b': 0}