In [1]:
import torch
from torch.utils.data import Dataset, IterableDataset, DataLoader
import torch.nn as nn
import random
import numpy as np
import math
from tqdm.auto import tqdm

In [2]:
# Prefer the GPU when one is present, but fall back to the CPU so the notebook
# still runs end-to-end (Restart & Run All) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
class SymbolicWorldIterableDataset(IterableDataset):
    """Endless stream of random symbolic observations with a presence label.

    Every sample is a dict with two entries:

    * ``'observation'``: 1-D integer tensor of ``observation_size`` object ids
      drawn uniformly from ``[0, num_objects)``.
    * ``'exist_1'``: ``1`` if ``target_object`` occurs in the observation,
      otherwise ``0``.

    Args:
        num_objects: Size of the object vocabulary (exclusive upper bound of
            the sampled ids). Must be at least 1.
        observation_size: Number of object ids per observation. Must be
            non-negative.
        target_object: Object id whose presence is labelled. Defaults to ``1``.
            The label key stays ``'exist_1'`` whatever the value, so existing
            consumers keep working.

    Raises:
        ValueError: If ``num_objects`` or ``observation_size`` is out of range.
    """

    def __init__(self, num_objects=32, observation_size=8, target_object=1):
        super().__init__()
        # Fail at construction time instead of deep inside torch.randint on
        # the first sample drawn.
        if num_objects < 1:
            raise ValueError("num_objects must be >= 1, got {}".format(num_objects))
        if observation_size < 0:
            raise ValueError("observation_size must be >= 0, got {}".format(observation_size))
        self.num_objects = num_objects
        self.observation_size = observation_size
        self.target_object = target_object

    def generate_observation(self):
        """Draw one random observation and label whether the target id is present.

        Returns:
            dict: ``{'observation': tensor of shape (observation_size,),
            'exist_1': int in {0, 1}}``.
        """
        observation = torch.randint(0, self.num_objects, (self.observation_size,))
        # Explicit element-wise comparison; equivalent to `target in observation`.
        exist_1 = int((observation == self.target_object).any())
        return {'observation': observation, 'exist_1': exist_1}

    def __iter__(self):
        """Yield freshly generated samples forever; the consumer must stop iteration."""
        while True:
            yield self.generate_observation()

# Infinite synthetic data stream, served to the training loop in batches of 1024.
iterable_dataset = SymbolicWorldIterableDataset()
loader = DataLoader(dataset=iterable_dataset, batch_size=1024)

# Disabled debugging aid: flip the condition to True to print the first
# observation and its label from a single batch.
if False:
    preview_batch = next(iter(loader))
    print(preview_batch['observation'][0])
    print(preview_batch['exist_1'][0])


In [5]:
class MySimpleModel(nn.Module):
    """Small Transformer encoder that classifies a sequence of object ids.

    A learned class-token embedding is prepended to the embedded object ids,
    the sequence is passed through a Transformer encoder, and the encoder
    output at the class-token position is projected to 2 logits.

    Args:
        d_model: Embedding / hidden width of the Transformer.
        n_heads: Number of attention heads per encoder layer.
        n_layers: Number of stacked encoder layers.
        n_objects: Size of the object-id vocabulary; input ids must lie in
            ``[0, n_objects)``.
        n_clss: Number of distinct class-token ids.

    All defaults reproduce the previously hard-coded configuration.
    """

    def __init__(self, d_model=32, n_heads=8, n_layers=2, n_objects=32, n_clss=5):
        super().__init__()
        self.d_model = d_model
        # NOTE(review): nn.TransformerEncoder appears to clone the layer it is
        # given, so the instance stored here may carry parameters that forward()
        # never uses. It is kept as an attribute to preserve the existing
        # attribute name and state_dict keys; confirm before removing.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True)
        self.transformer_layer = torch.nn.TransformerEncoder(self.encoder_layer, n_layers)
        self.token_embedding = torch.nn.Embedding(n_objects, d_model)
        self.clss_embedding = torch.nn.Embedding(n_clss, d_model)
        self.fc1 = torch.nn.Linear(d_model, 2)

    def forward(self, x, clss):
        '''
        x : long tensor: [batch, seq_len]
        clss: long tensor: [batch, 1] (a flat [batch] tensor is also accepted)
        return: tensor: [batch, 2]
        '''
        # Accept class ids without the explicit sequence axis.
        if clss.dim() == 1:
            clss = clss.unsqueeze(1)
        token_states = self.token_embedding(x)
        clss_states = self.clss_embedding(clss)
        # The class token goes first so its output sits at sequence position 0.
        sequence = torch.cat((clss_states, token_states), dim=1)
        encoded = self.transformer_layer(sequence)
        # Classify from the class-token position only.
        return self.fc1(encoded[:, 0, :])

In [6]:
# Two-class cross-entropy objective applied to the model's logits.
loss_module = nn.CrossEntropyLoss()

# Build the classifier and move its parameters to the selected device before
# handing them to the optimizer.
model = MySimpleModel()
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)

In [7]:
loss_history = []
max_steps = 1000
log_every_steps = 100
# The loader is infinite, so the loop is bounded by zipping it with
# range(max_steps): exactly max_steps optimizer updates are performed and the
# progress bar total is accurate. (The previous `if i > max_steps: break` check
# ran after the update and therefore executed max_steps + 2 steps.)
for i, batch in tqdm(zip(range(max_steps), loader), total=max_steps):
    x = batch['observation'].to(device)
    # Every sample uses class-token id 1; the tensor is created directly on the
    # target device instead of being built on the CPU and copied.
    clss = torch.ones((x.shape[0], 1), dtype=torch.long, device=device)
    y = batch['exist_1'].to(device)
    optimizer.zero_grad()
    output = model(x, clss)
    loss = loss_module(output, y)
    loss_history.append(loss.item())
    loss.backward()
    optimizer.step()
    if i % log_every_steps == 0:
        # Mean loss since the previous log line, then restart the window.
        print("{:.3f}".format(sum(loss_history)/len(loss_history)))
        loss_history = []
        # Metrics come from the current batch's pre-update forward pass.
        predictions = output.argmax(dim=1)
        acc = (predictions == y).sum() / output.shape[0]
        print('acc: ', acc)
        print('number of ones:', predictions.sum())
        # NOTE: the value printed here is the fraction of positive labels in
        # the batch; always predicting 0 would score 1 minus this value.
        print('random chance:', y.sum()/output.shape[0])
        print('------------------------')

  0%|          | 0/1000 [00:00<?, ?it/s]

1.223
acc:  tensor(0.2441, device='cuda:0')
number of ones: tensor(1024, device='cuda:0')
random chance: tensor(0.2441, device='cuda:0')
------------------------
0.675
acc:  tensor(0.7812, device='cuda:0')
number of ones: tensor(0, device='cuda:0')
random chance: tensor(0.2188, device='cuda:0')
------------------------
0.532
acc:  tensor(0.7803, device='cuda:0')
number of ones: tensor(0, device='cuda:0')
random chance: tensor(0.2197, device='cuda:0')
------------------------
0.506
acc:  tensor(0.7627, device='cuda:0')
number of ones: tensor(0, device='cuda:0')
random chance: tensor(0.2373, device='cuda:0')
------------------------
0.434
acc:  tensor(0.8369, device='cuda:0')
number of ones: tensor(74, device='cuda:0')
random chance: tensor(0.2354, device='cuda:0')
------------------------
0.177
acc:  tensor(0.9824, device='cuda:0')
number of ones: tensor(216, device='cuda:0')
random chance: tensor(0.2285, device='cuda:0')
------------------------
0.066
acc:  tensor(0.9922, device='cuda: