In [1]:
import torch
from transformers import RobertaModel

from LicGan.models_gan import Generator, Discriminator, gumbel_sigmoid
from LicGan.graph_data import get_loaders, SyntheticGraphDataset
import numpy as np

from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def print_network(model, name, log=None):
    """Print out the network information."""
    num_params = 0
    for p in model.parameters():
        num_params += p.numel()
    # print(model)
    print(name)
    print("The number of parameters: {}".format(num_params))
    if log is not None:
        log.info(model)
        log.info(name)
        log.info("The number of parameters: {}".format(num_params))

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Torch device: ', device)

Torch device:  cuda


In [4]:
log = None
N = 100
z_dim = 8
mha_dim = 768
n_heads = 8
dropout = 0
model_mode = 1
disc_dims = [[128, 128], [512, 768], [512, 256, 128]]
gen_dims = [[128, 256, 768], [512, 512]]
lm_model = 'roberta-base'

In [5]:
G = Generator(N,
            z_dim,
            gen_dims,
            mha_dim,
            n_heads,
            dropout,
            model_mode)
D = Discriminator(N,
                disc_dims, 
                mha_dim,
                n_heads,
                dropout,
                model_mode)

bert_D = RobertaModel.from_pretrained(lm_model)

print_network(G, 'G', log)
print_network(D, 'D', log)
print_network(bert_D, lm_model+'_D', log)

G.to(device)
D.to(device)
bert_D.to(device)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


G
The number of parameters: 5043252
D
The number of parameters: 10376769
roberta-base_D
The number of parameters: 124645632


RobertaModel(
  (embeddings): RobertaEmbeddings(
    (word_embeddings): Embedding(50265, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (token_type_embeddings): Embedding(1, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): RobertaEncoder(
    (layer): ModuleList(
      (0-11): 12 x RobertaLayer(
        (attention): RobertaAttention(
          (self): RobertaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): RobertaSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropou

In [6]:
restore_G = './results/100Node/models/100-G.ckpt'
restore_D = './results/100Node/models/100-D.ckpt'
restore_B_D = None

if restore_D:
    D.load_state_dict(torch.load(restore_D, map_location=lambda storage, loc: storage))
if restore_G:
    G.load_state_dict(torch.load(restore_G, map_location=lambda storage, loc: storage))
if restore_B_D:
    bert_D.pooler.load_state_dict(torch.load(restore_B_D, map_location=lambda storage, loc: storage))

In [7]:
max_len = 514
model_name='roberta-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [8]:
description = 'An airport with many planes.'
text = 'A graph representing a remote sensing image with the following description: ' + str(description)

token_text = tokenizer(text, add_special_tokens=True, truncation=False, max_length=max_len, padding='max_length')

In [9]:
# z = torch.from_numpy(z).to(device).float()
ids = torch.from_numpy(np.stack([token_text.input_ids])).to(device)
mask = torch.from_numpy(np.stack([token_text.attention_mask])).to(device)

In [10]:
bert_D_out = bert_D(ids, attention_mask=mask).last_hidden_state[:,:N,:]
bert_G_out = bert_D_out