In [1]:
import os, time
from collections import defaultdict

import matplotlib.pyplot as plt
import omegaconf, torch

from src import utils

In [2]:
def train(opt, model, optimizer):
    """Train ``model`` for ``opt.training.epochs`` epochs and return it.

    For every epoch the optimizer is passed through
    ``utils.update_learning_rate``; each batch of the "train" loader is
    preprocessed, forwarded through ``model(inputs, labels)`` and used for one
    optimizer step on ``scalar_outputs["Loss"]``. Per-epoch results are
    accumulated with ``utils.log_results`` and printed with
    ``utils.print_results``.

    Args:
        opt: Config object. This function reads ``opt.training.epochs`` and
            ``opt.training.val_idx`` and forwards ``opt`` to the ``utils``
            helpers.
        model: Callable as ``model(inputs, labels)``; must return a mapping
            with a ``"Loss"`` entry that supports ``.backward()``.
        optimizer: Object exposing ``zero_grad()`` and ``step()``.

    Returns:
        The ``model`` object that was passed in.
    """
    train_loader = utils.get_data(opt, "train")
    num_steps_per_epoch = len(train_loader)

    # Validation interval in epochs. -1 is the "disabled" sentinel; 0 is
    # treated as disabled too, because `epoch % 0` raises ZeroDivisionError.
    val_idx = opt.training.val_idx
    validation_enabled = val_idx not in (0, -1)

    for epoch in range(opt.training.epochs):
        # Start the clock per epoch so time spent validating the previous
        # epoch is not reported as training time.
        epoch_start = time.time()
        train_results = defaultdict(float)
        optimizer = utils.update_learning_rate(optimizer, opt, epoch)

        for inputs, labels in train_loader:
            inputs, labels = utils.preprocess_inputs(opt, inputs, labels)

            optimizer.zero_grad()

            scalar_outputs = model(inputs, labels)
            scalar_outputs["Loss"].backward()

            optimizer.step()

            train_results = utils.log_results(
                train_results, scalar_outputs, num_steps_per_epoch
            )

        utils.print_results("train", time.time() - epoch_start, train_results, epoch)

        # Validate.
        if validation_enabled and epoch % val_idx == 0:
            validate_or_test(opt, model, "val", epoch=epoch)

    return model


def validate_or_test(opt, model, partition, epoch=None):
    """Evaluate ``model`` on one data partition and print the results.

    Loads the loader for ``partition`` via ``utils.get_data``, switches the
    model to eval mode, runs ``model.forward_downstream_classification_model``
    on every batch under ``torch.no_grad()``, accumulates the outputs with
    ``utils.log_results`` and prints them with ``utils.print_results``.

    The model's previous train/eval mode is restored afterwards, also when
    evaluation raises, so calling this on a model that is already in eval mode
    no longer flips it back to training mode.

    Args:
        opt: Config object forwarded to the ``utils`` helpers.
        model: Model exposing ``eval()``, ``train(mode)`` and
            ``forward_downstream_classification_model(inputs, labels)``.
        partition: Partition name passed to ``utils.get_data`` (this file uses
            "val" and "test"); also printed and used as the results label.
        epoch: Optional epoch index forwarded to ``utils.print_results``.

    Returns:
        None.
    """
    test_time = time.time()
    test_results = defaultdict(float)

    data_loader = utils.get_data(opt, partition)
    num_steps_per_epoch = len(data_loader)

    # Remember the caller's mode. Falling back to True reproduces the old
    # behaviour (unconditional model.train()) for objects without `.training`.
    was_training = getattr(model, "training", True)
    model.eval()
    print(partition)
    try:
        with torch.no_grad():
            for inputs, labels in data_loader:
                inputs, labels = utils.preprocess_inputs(opt, inputs, labels)

                scalar_outputs = model.forward_downstream_classification_model(
                    inputs, labels
                )
                test_results = utils.log_results(
                    test_results, scalar_outputs, num_steps_per_epoch
                )
    finally:
        # Restore the original mode even if a batch fails.
        model.train(was_training)

    utils.print_results(partition, time.time() - test_time, test_results, epoch=epoch)


In [3]:
cfg = omegaconf.OmegaConf.load("config.yaml")
opt = utils.parse_args(cfg)

model, optimizer = utils.get_model_and_optimizer(opt)
model_file = "text_model.pt"
# If model exists, load it. Else, train it.
if os.path.isfile(model_file):
    print("Loading model...")
    # map_location places the checkpoint tensors on the configured device, so
    # a file saved on another device (e.g. a GPU) still loads here.
    # NOTE(review): torch.load unpickles the file; only load checkpoints you
    # trust (consider weights_only=True if the installed torch supports it).
    state_dict = torch.load(model_file, map_location=opt.device)
    model.load_state_dict(state_dict)
else:
    print("Training model...")
    model = train(opt, model, optimizer)
    # Cache the trained weights so later runs take the loading branch.
    torch.save(model.state_dict(), model_file)

# print("Testing model...")
# validate_or_test(opt, model, "test")

seed: 42
device: cpu
input:
  path: datasets
  batch_size: 128
  dataset: shakespeare
  mnist:
    encode_label: false
  shakespeare:
    sample_len: 64
    num_classes: 65
model:
  peer_normalization: 0
  momentum: 0.9
  hidden_dim: 256
  num_layers: 4
training:
  epochs: 1
  learning_rate: 0.001
  weight_decay: 0.0003
  momentum: 0.9
  downstream_learning_rate: 0.01
  downstream_weight_decay: 0.003
  val_idx: -1
  final_test: false
hydra:
  run:
    dir: logs

FF_model(
  (model): ModuleList(
    (0): Linear(in_features=64, out_features=256, bias=True)
    (1-3): 3 x Linear(in_features=256, out_features=256, bias=True)
  )
  (ff_loss): BCEWithLogitsLoss()
  (linear_classifier): Sequential(
    (0): Linear(in_features=768, out_features=65, bias=False)
  )
  (classification_loss): CrossEntropyLoss()
) 

Loading model...


In [4]:
import torch.nn as nn

@torch.no_grad()
def generate(self, idx, max_new_tokens=100, temperature=0.2, top_k=None):
    """
    Extend the conditioning sequence ``idx`` (shape (b, t)) by sampling
    ``max_new_tokens`` new tokens, feeding each prediction back into the model.
    Most likely you'll want to make sure to be in model.eval() mode of operation for this.

    Args:
        self: The model. It is called as ``self({'neutral': context})`` and must
            return a mapping with a ``"logits"`` entry; it must also expose
            ``self.opt.input.shakespeare.sample_len``.
        idx: Tensor of token indices of shape (b, t).
        max_new_tokens: Number of tokens to append.
        temperature: Logits are divided by this value before sampling; values
            below 1 sharpen the distribution. Must be non-zero.
        top_k: If given, only the ``top_k`` highest-scoring tokens of each row
            can be sampled.

    Returns:
        ``idx`` with ``max_new_tokens`` sampled columns appended along dim 1.
    """
    # Longest context handed to the model; constant across iterations.
    sample_len = self.opt.input.shakespeare.sample_len
    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it to the last sample_len tokens
        idx_cond = idx if idx.size(1) <= sample_len else idx[:, -sample_len:]
        # forward the model to get the logits for the next token
        # NOTE(review): assumes the model returns one row of logits per batch
        # element, shape (b, num_classes); no time-step indexing is done here — confirm.
        logits = self({'neutral': idx_cond})["logits"]
        # scale by desired temperature
        logits = logits / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            # Compare every row against its OWN k-th largest value. Indexing
            # with [-1] alone would select the last batch row instead.
            kth_value = top_values[..., -1:]
            logits = logits.masked_fill(logits < kth_value, -float('Inf'))
        # apply softmax to convert logits to (normalized) probabilities
        probs = nn.functional.softmax(logits, dim=-1)
        # sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1)
        # append sampled index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)

    return idx

train_loader = utils.get_data(opt, "train")
# The character <-> index tables are read from the dataset, not the loader.
dataset = train_loader.dataset

# Generate.
context = """First Citizen:
Before we proceed any further, hear me speak. No """
# Encode the prompt as a (1, t) batch. float32 matches what the legacy
# torch.Tensor(...) constructor produced here by default.
# NOTE(review): presumably the model consumes the raw indices as float features
# and needs len(context) == sample_len — confirm against the model's forward.
tokens = torch.tensor(
    [dataset.stoi[ch] for ch in context], dtype=torch.float32, device=opt.device
).unsqueeze(0)
generated = generate(model, tokens)
# `generated` has shape (1, t): decode row 0 element by element. Iterating over
# `generated` itself yields whole rows, on which .item() fails. int() turns the
# float token values back into integer lookup keys.
print("".join(dataset.itos[int(token.item())] for token in generated[0]))

AttributeError: 'DataLoader' object has no attribute 'itos'

In [None]:
# Inspect the dataset object wrapped by the train loader; as the cell's last
# expression, its repr is displayed as the cell output.
train_loader.dataset