In [1]:
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
cd "/content/drive/MyDrive/disentangled_transitions"

/content/drive/MyDrive/disentangled_transitions


In [3]:
!pip install colour
!pip install dm_env



In [4]:
import numpy as np
from data_utils import create_factorized_dataset
from data_utils import make_env
from data_utils import SpriteMaker, StateActionStateDataset
from functools import partial
from modules import *
from utils import *
import torch
import time
import torch.optim as optim
from torch.optim import lr_scheduler
from coda import get_true_flat_mask
from tqdm import tqdm
from sklearn import metrics

import structured_transitions
from structured_transitions import gen_samples_dynamic, TransitionsData, MixtureOfMaskedNetworks, SimpleStackedAttn, MaskedNetwork
from dynamic_scm_discovery import compute_metrics

In [5]:
SEED = 1
np.random.seed(SEED)
BATCH_SIZE = 1000
DATASET_SIZE = 50000
MASK_REGULARIZATION_COEFFICIENT = 0.
WEIGHT_LOSS_COEFFICIENT = 0.
ATTENTION_LOSS_COEFFICIENT = 0.
WEIGHT_DECAY = 0.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'

In [6]:
ground_truth_kwargs = dict(num_sprites=4, seed=SEED, max_episode_length=5000, imagedim=16)
config, env = make_env(**ground_truth_kwargs)
env.action_space.seed(SEED)  # reproduce randomness in action space
sprite_maker = SpriteMaker(partial(make_env, **ground_truth_kwargs))
data, sprites = create_factorized_dataset(env, DATASET_SIZE)

s, a, r, s2 = list(zip(*data))
s = np.array(s)
a = np.array(a)
s2 = np.array(s2)
ground_truth_masks = []
for s_, a_ in tqdm(zip(s, a)): # 50000 iters
 mask = get_true_flat_mask(sprite_maker(s_), config, a_)
 mask = mask[:, :-2]
 ground_truth_masks.append(mask)

zeros = np.zeros((DATASET_SIZE, 2))
a = np.concatenate((zeros, np.array(a)), axis=1).reshape((DATASET_SIZE, 1, -1))
sa = np.concatenate((s, a), axis=1)
ground_truth_masks = np.array(ground_truth_masks)

samples = (
  torch.FloatTensor(sa),
  torch.FloatTensor(s2), 
  torch.FloatTensor(ground_truth_masks)
)

dataset = TransitionsData(samples)
tr = TransitionsData(dataset[:int(len(dataset)*5/6)])
te = TransitionsData(dataset[int(len(dataset)*5/6):])

..................................................

50000it [01:32, 541.33it/s]


In [7]:
train_loader = torch.utils.data.DataLoader(tr, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, drop_last=True)
test_loader  = torch.utils.data.DataLoader(te, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, drop_last=True)

In [8]:
off_diag = np.ones([5, 5])


rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
rel_rec = torch.FloatTensor(rel_rec)
rel_send = torch.FloatTensor(rel_send)


In [12]:
encoder = MLPEncoder(4, 256, 2, 0.0, True)

decoder = SingleStepDecoder(4, 2, 256, 256, 256, 0.0, True)
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()),
                       lr=0.0005)
scheduler = lr_scheduler.StepLR(optimizer, step_size=200,
                                gamma=0.5)
rel_rec = rel_rec.cuda()
rel_send = rel_send.cuda()
encoder.cuda()
decoder.cuda()

prior = np.array([.999, 0.001])  # TODO: hard coded for now
print("Using prior")
print(prior)
log_prior = torch.FloatTensor(np.log(prior))
log_prior = torch.unsqueeze(log_prior, 0)
log_prior = torch.unsqueeze(log_prior, 0)
log_prior = Variable(log_prior)
log_prior = log_prior.cuda()

def train(epoch):
  t = time.time()
  nll_train = []
  acc_train = []
  kl_train = []
  mse_train = []
  encoder.train()
  decoder.train() 
  for batch_idx, (x, y, m) in enumerate(train_loader):
    x = Variable(x.cuda())
    y = Variable(y.cuda())
    optimizer.zero_grad()

    logits = encoder(x, rel_rec, rel_send)
    edges = gumbel_softmax(logits, tau=0.5, hard=False)
    prob = my_softmax(logits, -1)
    pred = decoder(x, edges, rel_rec, rel_send)

    loss_nll = nll_gaussian(pred, y, 5e-5)

    loss_kl = kl_categorical(prob, log_prior, 5)

    loss = 0.1 * loss_nll + loss_kl
    loss.backward()
    optimizer.step()
    scheduler.step()

    mse_train.append(F.mse_loss(pred, y).item())
    nll_train.append(loss_nll.item())
    kl_train.append(loss_kl.item())
  
  edge = edges[0]
  mask = np.zeros((5, 5))
  index_1 = torch.argmax(rel_rec, dim = 1)
  index_2 = torch.argmax(rel_send, dim = 1)
  _edge = torch.argmax(edge, dim=1)
  for i in range(25):
    mask[index_1[i], index_2[i]] = _edge[i]
  print(mask)

  print('Epoch: {:04d}'.format(epoch),
        'nll_train: {:.10f}'.format(np.mean(nll_train)),
        'kl_train: {:.10f}'.format(np.mean(kl_train)),
        'mse_train: {:.10f}'.format(np.mean(mse_train)),
        # 'acc_train: {:.10f}'.format(np.mean(acc_train)),
        'time: {:.4f}s'.format(time.time() - t))
  return(np.mean(nll_train))
# def train(epoch):
#     t = time.time()
#     nll_train = []
#     acc_train = []
#     kl_train = []
#     mse_train = []

#     encoder.train()
#     decoder.train()
#     for batch_idx, (data) in enumerate(train_loader):

#         data = data.cuda()
#         data = Variable(data)

#         optimizer.zero_grad()

#         logits = encoder(data, rel_rec, rel_send)
        # edges = gumbel_softmax(logits, tau=0.5, hard=False)
        # prob = my_softmax(logits, -1)


        # output = decoder(data, edges, rel_rec, rel_send,
        #                   10)

#         target = data[:, :, 1:, :]

#         loss_nll = nll_gaussian(output, target, 5e-5)


#         loss_kl = kl_categorical_uniform(prob, 5,
#                                           2)

#         loss = loss_nll + loss_kl

#         # acc = edge_accuracy(logits, relations)
#         # acc_train.append(acc)

#         loss.backward()
#         optimizer.step()
#         scheduler.step()

#         mse_train.append(F.mse_loss(output, target).item())
#         nll_train.append(loss_nll.item())
#         kl_train.append(loss_kl.item())

#     print('Epoch: {:04d}'.format(epoch),
#           'nll_train: {:.10f}'.format(np.mean(nll_train)),
#           'kl_train: {:.10f}'.format(np.mean(kl_train)),
#           'mse_train: {:.10f}'.format(np.mean(mse_train)),
#           # 'acc_train: {:.10f}'.format(np.mean(acc_train)),
#           'time: {:.4f}s'.format(time.time() - t))
#     return(np.mean(nll_train))

# def test():
#     acc_test = []
#     nll_test = []
#     kl_test = []
#     mse_test = []
#     tot_mse = 0
#     counter = 0

#     encoder.eval()
#     decoder.eval()
#     i=0
#     for batch_idx, (data) in enumerate(test_loader):
#         data.cuda()
#         data = Variable(data, volatile=True)

#         data_encoder = data[:, :, :49, :].contiguous().cuda()
#         data_decoder = data[:, :, -49:, :].contiguous().cuda()

#         logits = encoder(data_encoder, rel_rec, rel_send)
#         edges = gumbel_softmax(logits, tau=0.5, hard=True)
#         i += 1
#         if i % 100 == 1:
          # edge = edges[0]
          # mask = np.zeros((5, 5))
          # index_1 = torch.argmax(rel_rec, dim = 1)
          # index_2 = torch.argmax(rel_send, dim = 1)
          # _edge = torch.argmax(edge, dim=1)
          # for i in range(20):
          #   mask[index_1[i], index_2[i]] = _edge[i]
          # print(mask)

        

#         prob = my_softmax(logits, -1)

#         output = decoder(data_decoder, edges, rel_rec, rel_send, 1)

#         target = data_decoder[:, :, 1:, :]
#         loss_nll = nll_gaussian(output, target, 5e-5)
#         loss_kl = kl_categorical_uniform(prob, 5, 2)

#         mse_test.append(F.mse_loss(output, target).item())
#         nll_test.append(loss_nll.item())
#         kl_test.append(loss_kl.item())
#     print(len(mse_test))
#     print('--------------------------------')
#     print('--------Testing-----------------')
#     print('--------------------------------')
#     print('nll_test: {:.10f}'.format(np.mean(nll_test)),
#           'kl_test: {:.10f}'.format(np.mean(kl_test)),
#           'mse_test: {:.10f}'.format(np.mean(mse_test)),
#           'acc_test: {:.10f}'.format(np.mean(acc_test)))
for epoch in range(100):
  train(epoch)
# test()

Using factor graph MLP encoder.
Using learned interaction net decoder.
Using prior
[0.999 0.001]


  nn.init.xavier_normal(m.weight.data)
  nn.init.xavier_normal(m.weight.data)
  soft_max_1d = F.softmax(trans_input)


[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
Epoch: 0000 nll_train: 67.8975484895 kl_train: 6.4344193819 mse_train: 0.0016974387 time: 1.9159s
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
Epoch: 0001 nll_train: 60.9418048393 kl_train: 2.8502918220 mse_train: 0.0015235451 time: 1.8865s
[[0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
Epoch: 0002 nll_train: 54.8243586843 kl_train: 2.1461386390 mse_train: 0.0013706090 time: 1.8956s
[[0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 1. 0. 1. 1.]
 [0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 1.]]
Epoch: 0003 nll_train: 50.2570379304 kl_train: 1.6373017212 mse_train: 0.0012564259 time: 1.8906s
[[0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
Epoch: 0004 nll_train: 47.0096140606 kl_train: 1.3754223454 mse_train: 0.0011752404 time: 1.8736s
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 

KeyboardInterrupt: ignored