In [36]:
import torch
from torch.utils.data import Dataset, IterableDataset, DataLoader
import torch.nn as nn
import random
import numpy as np
import math
from tqdm.auto import tqdm
import torchmetrics

In [37]:
device = torch.device('cuda')

In [38]:
class SymbolicWorldIterableDataset(IterableDataset):

    def __init__(self, num_objects=32, observation_size=8):
        self.num_objects = num_objects
        self.observation_size = observation_size
        pass

    def generate_observation(self):
        observation = torch.randint(0, self.num_objects, (self.observation_size,))
        exist_1 = int(1 in observation)
        return {'observation':observation, 'exist_1':exist_1}

    def __iter__(self):
        while(True):
            yield self.generate_observation()

iterable_dataset = SymbolicWorldIterableDataset()
loader = DataLoader(iterable_dataset, batch_size=1024)

test_loaders_list = []
for observation_size in [4,16,32,64]:
    test_loader_name = f'observation_size_{observation_size}'
    test_loader_data = DataLoader(SymbolicWorldIterableDataset(observation_size=observation_size), batch_size=1024)
    test_loaders_list.append({
        'name': test_loader_name,
        'data_loader': test_loader_data
    })

test_loader_nobj_16 = DataLoader(SymbolicWorldIterableDataset(num_objects=16), batch_size=1024)
test_loader_nobj_4 = DataLoader(SymbolicWorldIterableDataset(num_objects=4), batch_size=1024)

if False:
    for a in loader:
        print(a['observation'][0])
        print(a['exist_1'][0])
        break


In [39]:
class MySimpleModel(nn.Module):
    
    def __init__(self):
        super(MySimpleModel, self).__init__()
        d_model = 32
        n_heads = 8
        n_layers = 1
        n_objects = 32
        n_clss = 5
        self.d_model = d_model
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True)
        self.transformer_layer = torch.nn.TransformerEncoder(self.encoder_layer, n_layers)
        self.token_embedding = torch.nn.Embedding(n_objects, d_model)
        self.clss_embedding = torch.nn.Embedding(n_clss, d_model)
        self.fc1 = torch.nn.Linear(d_model, 2)
        
    def forward(self, x, clss):
        '''
        x : long tensor: [batch, seq_len]
        clss: long tensor: [batch, 1]
        return: tensor: [batch, 2]
        '''
        temp = x
        temp = self.token_embedding(temp)
        clss_embedding = self.clss_embedding(clss)
        temp = torch.cat((clss_embedding, temp), dim=1)
        temp = self.transformer_layer(temp)
        temp = temp[:, 0, :]
        temp = self.fc1(temp)
        return temp

In [40]:
model = MySimpleModel().to(device)
loss_module = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
f1_module = torchmetrics.F1Score(num_classes=2).to(device)

In [41]:
def eval(test_loader):
    with torch.no_grad():
            test_f1s = []
            for j, batch in enumerate(test_loader):
                x = batch['observation'].to(device)
                clss = torch.ones((x.shape[0], 1)).long().to(device)
                y = batch['exist_1'].to(device)
                output = model(x, clss)
                preds = output.argmax(dim=1)
                test_f1s.append(f1_module(preds, y).item())
                if j > 10:
                    break
            f1 = sum(test_f1s)/len(test_f1s)
            return round(f1, 4)

In [42]:
loss_history = []
max_steps = 20000
log_every_steps = 200
for i, batch in tqdm(enumerate(loader), total=max_steps):
    x = batch['observation'].to(device)
    clss = torch.ones((x.shape[0], 1)).long().to(device)
    y = batch['exist_1'].to(device)
    optimizer.zero_grad()
    output = model(x, clss)
    loss = loss_module(output, y)
    loss_history.append(loss.item())
    loss.backward()
    optimizer.step()
    if i > max_steps:
        break
    if i % log_every_steps == 0 :
        print("step {}".format(i))
        print("loss: {:.2e}".format(sum(loss_history)/len(loss_history)))
        loss_history = []
        train_f1 = f1_module(output.argmax(dim=1), y).item()
        print('train f1: ', round(train_f1,4))
        for test_loader in test_loaders_list:
            test_acc = eval(test_loader['data_loader'])
            print(test_loader['name'], " f1_score: " , test_acc)
        print('------------------------')

  0%|          | 0/20000 [00:00<?, ?it/s]

step 0
loss: 0.731182
train f1:  0.3848
observation_size_4  f1_score:  0.6687
observation_size_16  f1_score:  0.5536
observation_size_32  f1_score:  0.3958
observation_size_64  f1_score:  0.2242
------------------------
step 200
loss: 0.210286
train f1:  1.0
observation_size_4  f1_score:  0.997
observation_size_16  f1_score:  0.9747
observation_size_32  f1_score:  0.6893
observation_size_64  f1_score:  0.4633
------------------------
step 400
loss: 0.007178
train f1:  1.0
observation_size_4  f1_score:  0.9989
observation_size_16  f1_score:  0.9928
observation_size_32  f1_score:  0.7767
observation_size_64  f1_score:  0.5584
------------------------
step 600
loss: 0.003843
train f1:  1.0
observation_size_4  f1_score:  0.9997
observation_size_16  f1_score:  0.9948
observation_size_32  f1_score:  0.7791
observation_size_64  f1_score:  0.5496
------------------------
step 800
loss: 0.002619
train f1:  1.0
observation_size_4  f1_score:  0.9997
observation_size_16  f1_score:  0.9972
observat