In [1]:
import torch
from torch.utils.data import Dataset, IterableDataset, DataLoader
import torch.nn as nn
import random
import numpy as np
import math
from tqdm.auto import tqdm

In [2]:
device = torch.device('cuda')

In [3]:
class SymbolicWorldIterableDataset(IterableDataset):

    def __init__(self, num_objects=32, observation_size=8):
        self.num_objects = num_objects
        self.observation_size = observation_size
        pass

    def generate_observation(self):
        observation = torch.randint(0, self.num_objects, (self.observation_size,))
        exist_1 = int(1 in observation)
        return {'observation':observation, 'exist_1':exist_1}

    def __iter__(self):
        while(True):
            yield self.generate_observation()

iterable_dataset = SymbolicWorldIterableDataset()
loader = DataLoader(iterable_dataset, batch_size=1024)

test_loader_size_2 = DataLoader(SymbolicWorldIterableDataset(observation_size=2), batch_size=1024)
test_loader_size_4 = DataLoader(SymbolicWorldIterableDataset(observation_size=4), batch_size=1024)
test_loader_size_16 = DataLoader(SymbolicWorldIterableDataset(observation_size=16), batch_size=1024)
test_loader_size_32 = DataLoader(SymbolicWorldIterableDataset(observation_size=32), batch_size=1024)
test_loader_size_64 = DataLoader(SymbolicWorldIterableDataset(observation_size=64), batch_size=1024)
test_loader_size_128 = DataLoader(SymbolicWorldIterableDataset(observation_size=128), batch_size=1024)

test_loader_nobj_16 = DataLoader(SymbolicWorldIterableDataset(num_objects=16), batch_size=1024)
test_loader_nobj_4 = DataLoader(SymbolicWorldIterableDataset(num_objects=4), batch_size=1024)

if False:
    for a in loader:
        print(a['observation'][0])
        print(a['exist_1'][0])
        break


In [4]:
class MySimpleModel(nn.Module):
    
    def __init__(self):
        super(MySimpleModel, self).__init__()
        d_model = 32
        n_heads = 8
        n_layers = 1
        n_objects = 32
        n_clss = 5
        self.d_model = d_model
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True)
        self.transformer_layer = torch.nn.TransformerEncoder(self.encoder_layer, n_layers)
        self.token_embedding = torch.nn.Embedding(n_objects, d_model)
        self.clss_embedding = torch.nn.Embedding(n_clss, d_model)
        self.fc1 = torch.nn.Linear(d_model, 2)
        
    def forward(self, x, clss):
        '''
        x : long tensor: [batch, seq_len]
        clss: long tensor: [batch, 1]
        return: tensor: [batch, 2]
        '''
        temp = x
        temp = self.token_embedding(temp)
        clss_embedding = self.clss_embedding(clss)
        temp = torch.cat((clss_embedding, temp), dim=1)
        temp = self.transformer_layer(temp)
        temp = temp[:, 0, :]
        temp = self.fc1(temp)
        return temp

In [5]:
model = MySimpleModel().to(device)
loss_module = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

In [6]:
def eval(test_loader):
    with torch.no_grad():
            test_accs = []
            for j, batch in enumerate(test_loader):
                x = batch['observation'].to(device)
                clss = torch.ones((x.shape[0], 1)).long().to(device)
                y = batch['exist_1'].to(device)
                output = model(x, clss)
                acc = ((output.argmax(dim=1) == y).sum() / output.shape[0]).item()
                test_accs.append(acc)
                if j > 10:
                    break
            return sum(test_accs)/len(test_accs)

In [7]:
loss_history = []
max_steps = 20000
log_every_steps = 200
for i, batch in tqdm(enumerate(loader), total=max_steps):
    x = batch['observation'].to(device)
    clss = torch.ones((x.shape[0], 1)).long().to(device)
    y = batch['exist_1'].to(device)
    optimizer.zero_grad()
    output = model(x, clss)
    loss = loss_module(output, y)
    loss_history.append(loss.item())
    loss.backward()
    optimizer.step()
    if i > max_steps:
        break
    if i % log_every_steps == 0 :
        print("step {}".format(i))
        print("loss: {:.3f}".format(sum(loss_history)/len(loss_history)))
        loss_history = []
        train_acc = ((output.argmax(dim=1) == y).sum() / output.shape[0]).item()
        print('train acc: ', train_acc)
        print('test acc size 2: {:.3f}'.format(eval(test_loader_size_2)))
        print('test acc size 4: {:.3f}'.format(eval(test_loader_size_4)))
        print('test acc size 16: {:.3f}'.format(eval(test_loader_size_16)))
        print('test acc size 64: {:.3f}'.format(eval(test_loader_size_64)))
        print('test acc size 32: {:.3f}'.format(eval(test_loader_size_32)))
        print('test acc size 128: {:.3f}'.format(eval(test_loader_size_128)))
        print('test acc nojb 4: {:.3f}'.format(eval(test_loader_nobj_4)))
        print('test acc nojb 16: {:.3f}'.format(eval(test_loader_nobj_16)))
        print('------------------------')

  0%|          | 0/20000 [00:00<?, ?it/s]

step 0
loss: 0.634
train acc:  0.7421875
test acc size 2: 0.885
test acc size 4: 0.849
test acc size 16: 0.596
test acc size 64: 0.140
test acc size 32: 0.363
test acc size 128: 0.022
test acc nojb 4: 0.107
test acc nojb 16: 0.575
------------------------
step 200
loss: 0.537
train acc:  0.7783203125
test acc size 2: 0.937
test acc size 4: 0.877
test acc size 16: 0.607
test acc size 64: 0.134
test acc size 32: 0.361
test acc size 128: 0.019
test acc nojb 4: 0.102
test acc nojb 16: 0.587
------------------------
step 400
loss: 0.502
train acc:  0.79296875
test acc size 2: 0.946
test acc size 4: 0.882
test acc size 16: 0.602
test acc size 64: 0.130
test acc size 32: 0.356
test acc size 128: 0.017
test acc nojb 4: 0.149
test acc nojb 16: 0.593
------------------------
step 600
loss: 0.375
train acc:  0.9052734375
test acc size 2: 0.977
test acc size 4: 0.970
test acc size 16: 0.737
test acc size 64: 0.241
test acc size 32: 0.491
test acc size 128: 0.089
test acc nojb 4: 0.812
test acc noj