In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import AdamW
import transformers
from transformers import get_linear_schedule_with_warmup
import math
import numpy as np

from load_set import *
import model_training
from model_training import train_bern_model
import coin_i3C_modeling as ci3C
from rnn_modeling import *
from ntm import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Data / model dimensions -------------------------------------------------
BATCH_SIZE = 50                                       # passed as B to generate_ch_batches
MAX_POSITION_EMBEDDINGS = TEST_POSITION_EMBEDDINGS = 20  # T_train and T_test
VOCAB_SIZE = MAX_POSITION_EMBEDDINGS                  # default; may be overridden per task below
NUM_LABELS = 2
HIDDEN_SIZE = 256

# --- Dataset generation ------------------------------------------------------
# Alternative large-scale setup kept for reference:
#   NUM_TRAIN_SAMPLES = 3_500_000  # 256 * 380
#   SAMPLE_METHOD = "static-warmup"
NUM_WARMUP_STEPS = 2_000_000       # forwarded to generate_ch_batches as num_warmup_steps
NUM_TRAIN_SAMPLES, NUM_TEST_SAMPLES = 50_000, 1000
SAMPLE_METHOD = "static"

# --- Optimisation ------------------------------------------------------------
LEARNING_RATE, EPS = 1e-5, 1e-8    # AdamW lr / eps
EPOCHS = 10

In [3]:
# Display the configured (train, test) sample counts as a quick sanity check.
NUM_TRAIN_SAMPLES, NUM_TEST_SAMPLES

(50000, 1000)

In [4]:
# Task selector. Values handled by the generator dispatch later in the notebook:
#   "bucket-sort", "duplicate-string", "parity-check", "missing-duplicate-string"
# Alternatives kept for quick switching:
#   TASK = "duplicate-string"
#   TASK = "parity-check"
#   TASK = "bucket-sort"
TASK = "missing-duplicate-string"

In [5]:
# Two tasks override the default vocabulary size; every other task keeps the
# VOCAB_SIZE that is already set.
# NOTE(review): the override sticks in the kernel. When switching TASK, re-run
# the configuration cell before this one, otherwise the previous value remains.
if TASK in ("parity-check", "missing-duplicate-string"):
    VOCAB_SIZE = 2 if TASK == "parity-check" else 3

In [6]:
class GenConfig:
    """Generic attribute bag: each keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        """Store every ``name=value`` pair from ``kwargs`` on the instance."""
        for name in kwargs:
            setattr(self, name, kwargs[name])

In [7]:
# Disabled model candidate: the guard is falsy, so nothing below runs.
if False:
    model = NTMForParityCheck(
        config=RNNConfig(
            hidden_size=HIDDEN_SIZE,
            intermediate_size=HIDDEN_SIZE,
            num_hidden_layers=1,
            hidden_dropout_prob=0.0,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            vocab_size=VOCAB_SIZE,
            num_labels=NUM_LABELS,
        ),
    )

In [8]:
# Disabled model candidate: the guard is falsy, so nothing below runs.
if False:
    model = sLSTMForParityCheck(
        config=RNNConfig(
            hidden_size=HIDDEN_SIZE,
            intermediate_size=HIDDEN_SIZE,
            num_hidden_layers=1,
            hidden_dropout_prob=0.0,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            vocab_size=VOCAB_SIZE,
            num_labels=NUM_LABELS,
        ),
    )

In [9]:
# Active model: this is the only construction cell whose guard is truthy, so
# the `model` used by the rest of the notebook is created here.
if True:
    model = COINForBucketSort(
        config=RNNConfig(
            hidden_size=HIDDEN_SIZE,
            intermediate_size=HIDDEN_SIZE,
            num_hidden_layers=1,
            hidden_dropout_prob=0.1,
            layer_norm_eps=1e-12,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            vocab_size=VOCAB_SIZE,
            num_labels=NUM_LABELS,
        ),
    )

In [10]:
# Disabled model candidate: the guard is falsy, so nothing below runs.
if False:
    model = COINForParityCheck(
        config=RNNConfig(
            hidden_size=HIDDEN_SIZE,
            intermediate_size=HIDDEN_SIZE,
            num_hidden_layers=1,
            hidden_dropout_prob=0.1,
            layer_norm_eps=1e-12,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            vocab_size=VOCAB_SIZE,
            num_labels=NUM_LABELS,
        ),
    )

In [11]:
# Disabled model candidate: the guard is falsy, so nothing below runs.
if False:
    model = RNNForParityCheck(
        config=RNNConfig(
            hidden_size=HIDDEN_SIZE,
            num_hidden_layers=1,
            hidden_dropout_prob=0.1,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            vocab_size=VOCAB_SIZE,
            num_labels=NUM_LABELS,
        ),
    )

In [12]:
# Disabled model candidate: the guard is falsy, so nothing below runs.
if False:
    model = StackRNNForParityCheck(
        config=RNNConfig(
            hidden_size=HIDDEN_SIZE,
            num_hidden_layers=1,
            hidden_dropout_prob=0.5,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            vocab_size=VOCAB_SIZE,
            num_labels=NUM_LABELS,
        ),
    )

In [13]:
if 0:
    # B=50
    # T_train=20
    # T_test= >500
    class LSTMForParityCheck(nn.Module):
        """Sequence classifier: one-hot tokens -> linear embedding -> LSTM -> linear head.

        The class logits are computed from the LSTM output at the last time step.
        """

        def __init__(self, config):
            """Build the layers.

            Args:
                config: object exposing ``vocab_size``, ``hidden_size``,
                    ``num_hidden_layers``, ``hidden_dropout_prob`` and ``num_labels``.
            """
            super().__init__()
            self.config = config
            self.embeddings = nn.Linear(config.vocab_size, config.hidden_size)
            self.n_layers = config.num_hidden_layers
            # nn.LSTM only applies dropout between stacked layers; with a single
            # layer a non-zero value has no effect and merely triggers a warning.
            lstm_dropout = config.hidden_dropout_prob if self.n_layers > 1 else 0.0
            self.lstm = nn.LSTM(
                config.hidden_size,
                config.hidden_size,
                batch_first=True,
                num_layers=self.n_layers,
                dropout=lstm_dropout,
            )
            # Not used by forward(); kept so the parameter set and state_dict
            # keys stay the same as before.
            self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
            self.classifier = nn.Sequential(
                nn.Linear(config.hidden_size, config.num_labels)
            )
            self.loss_fn = nn.CrossEntropyLoss()

        def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
            """Classify each sequence in the batch.

            Args:
                input_ids: integer token ids; one-hot encoded with ``config.vocab_size`` classes.
                attention_mask: accepted for interface compatibility, ignored.
                labels: optional targets for the cross-entropy loss.
                **kwargs: ignored.

            Returns:
                ``ci3C.COINOutputClass`` with ``logits`` from the last time step and
                ``loss`` (``None`` when ``labels`` is not given).
            """
            one_hot = F.one_hot(input_ids.long(), self.config.vocab_size).float()
            embedded = self.embeddings(one_hot)
            # No initial state is passed, so nn.LSTM starts from zeros. The former
            # random (h0, c0) tuple was never handed to the LSTM and only consumed
            # RNG draws and memory, so it has been removed.
            sequence_output, _ = self.lstm(embedded)
            r_logits = self.classifier(sequence_output[:, -1])
            loss = self.loss_fn(r_logits, labels) if labels is not None else None
            return ci3C.COINOutputClass(
                logits=r_logits,
                loss=loss
            )

    # Instantiate the LSTM baseline defined above with a single recurrent layer.
    model = LSTMForParityCheck(
        config=RNNConfig(
            hidden_size=HIDDEN_SIZE,
            num_hidden_layers=1,
            hidden_dropout_prob=0.5,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            vocab_size=VOCAB_SIZE,
            num_labels=NUM_LABELS,
        ),
    )

In [14]:
if 0:
    class BertForBucketSort(nn.Module):
        """Per-token classifier: one-hot tokens -> linear embedding -> BERT encoder -> LM head."""

        def __init__(self, config):
            """Build the encoder and projection layers.

            Args:
                config: configuration passed to ``transformers.BertModel``; its
                    ``vocab_size`` and ``hidden_size`` also size the linear layers.
            """
            super().__init__()
            self.config = config
            self.bert = transformers.BertModel(config)
            self.embeddings = nn.Linear(config.vocab_size, config.hidden_size)
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.loss_fn = nn.CrossEntropyLoss()

        def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
            """Predict a vocabulary distribution for every position.

            Args:
                input_ids: integer token ids; one-hot encoded with ``config.vocab_size`` classes.
                attention_mask: forwarded to the BERT encoder.
                labels: optional per-position targets for the cross-entropy loss.
                **kwargs: ignored.

            Returns:
                ``ci3C.COINOutputClass`` with per-position ``logits`` and ``loss``
                (``None`` when ``labels`` is not given).
            """
            emb = self.embeddings(F.one_hot(input_ids, self.config.vocab_size).float())
            hidden_states = self.bert(inputs_embeds=emb, attention_mask=attention_mask).last_hidden_state
            logits = self.lm_head(hidden_states)
            loss = None
            if labels is not None:
                # Flatten positions so every token is one classification example;
                # reshape (rather than view) also accepts non-contiguous tensors.
                loss = self.loss_fn(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))
            return ci3C.COINOutputClass(
                logits=logits,
                loss=loss
            )


    # Instantiate the single-layer, single-head BERT variant defined above.
    model = BertForBucketSort(
        config=transformers.BertConfig(
            hidden_size=HIDDEN_SIZE,
            num_hidden_layers=1,
            num_attention_heads=1,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            vocab_size=VOCAB_SIZE,
        ),
    )

In [15]:
if 0:
    class BertForParityCheck(nn.Module):
        """Sequence classifier: one-hot tokens -> linear embedding -> BertForSequenceClassification."""

        def __init__(self, config):
            """Build the wrapped classifier and the input projection.

            Args:
                config: configuration passed to
                    ``transformers.BertForSequenceClassification``; its ``vocab_size``,
                    ``hidden_size`` and ``num_labels`` also size the linear layers.
            """
            super().__init__()
            self.config = config
            self.bert = transformers.BertForSequenceClassification(config)
            self.embeddings = nn.Linear(config.vocab_size, config.hidden_size)
            # Not used by forward() (the wrapped model supplies the logits); kept so
            # the parameter set and state_dict keys stay the same as before.
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
            self.loss_fn = nn.CrossEntropyLoss()

        def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
            """Classify each sequence in the batch.

            Args:
                input_ids: integer token ids; one-hot encoded with ``config.vocab_size`` classes.
                attention_mask: forwarded to the wrapped BERT model.
                labels: optional targets for the cross-entropy loss.
                **kwargs: ignored.

            Returns:
                ``ci3C.COINOutputClass`` with ``logits`` and ``loss`` (``None`` when
                ``labels`` is not given).
            """
            emb = F.one_hot(input_ids, self.config.vocab_size).float()
            emb = self.embeddings(emb)
            logits = self.bert(inputs_embeds=emb, attention_mask=attention_mask).logits
            loss = self.loss_fn(logits, labels) if labels is not None else None
            return ci3C.COINOutputClass(
                logits=logits,
                loss=loss
            )


    # Instantiate the single-layer, single-head BERT classifier defined above.
    model = BertForParityCheck(
        config=transformers.BertConfig(
            hidden_size=HIDDEN_SIZE,
            num_hidden_layers=1,
            num_attention_heads=1,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            vocab_size=VOCAB_SIZE,
            num_labels=NUM_LABELS,
        ),
    )

In [16]:
# Pick the dataset generator for the selected task. An unknown TASK fails here,
# loudly, instead of silently reusing a `gen_fn` left over from an earlier run
# of this cell (or failing later with a NameError).
if TASK == "bucket-sort":
    gen_fn = generate_bucket_sort_set
elif TASK == "duplicate-string":
    gen_fn = generate_duplicate_string_set
elif TASK == "parity-check":
    gen_fn = generate_parity_check_set
elif TASK == "missing-duplicate-string":
    gen_fn = generate_missing_duplicate_string_set
else:
    raise ValueError(
        f"Unknown TASK {TASK!r}; expected one of 'bucket-sort', 'duplicate-string', "
        "'parity-check', 'missing-duplicate-string'"
    )

# Alternative generator kept for reference (uniform batches, same length for train and test):
#train_buf = generate_uniform_batches(gen_fn, B=BATCH_SIZE, T=MAX_POSITION_EMBEDDINGS, vocab_size=VOCAB_SIZE, num_samples=NUM_TRAIN_SAMPLES)
#test_buf = generate_uniform_batches(gen_fn, B=BATCH_SIZE, T=MAX_POSITION_EMBEDDINGS, vocab_size=VOCAB_SIZE, num_samples=NUM_TEST_SAMPLES)

# Build the train and test buffers from the configuration constants.
train_buf, test_buf = generate_ch_batches(
    generate_set_fn=gen_fn,
    B=BATCH_SIZE,
    T_train=MAX_POSITION_EMBEDDINGS,
    T_test=TEST_POSITION_EMBEDDINGS,
    vocab_size=VOCAB_SIZE,
    num_train_samples=NUM_TRAIN_SAMPLES,
    num_test_samples=NUM_TEST_SAMPLES,
    sample_method=SAMPLE_METHOD,
    num_warmup_steps=NUM_WARMUP_STEPS,
)

sample schema: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 

In [17]:
#train_buf, test_buf = generate_static_parity_check_set(BATCH_SIZE, MAX_POSITION_EMBEDDINGS, TEST_POSITION_EMBEDDINGS, NUM_TRAIN_SAMPLES, NUM_TEST_SAMPLES)

In [18]:
#BATCH_SCHEMA = ["input_ids", "decoder_input_ids", "attention_mask", "labels"]
#BATCH_SCHEMA = ["input_ids", "labels"]

In [19]:
# Sorted distinct label values found in the first entry of the training buffer.
# NOTE(review): only train_buf[0] is inspected; a label that never occurs in
# that entry would be missing from both maps. Confirm this is acceptable.
labels = np.unique(train_buf[0]["labels"]).tolist()

# Index <-> label lookup tables; indices follow the sorted label order.
id2label = dict(enumerate(labels))
label2id = {label: idx for idx, label in id2label.items()}

print(label2id)
print(id2label)
print(labels)

{0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10, 11: 11, 12: 12, 13: 13, 14: 14, 15: 15, 16: 16, 17: 17, 18: 18, 19: 19}
{0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10, 11: 11, 12: 12, 13: 13, 14: 14, 15: 15, 16: 16, 17: 17, 18: 18, 19: 19}
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]


In [20]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=EPS)

# Total number of scheduler steps over the whole run.
# NOTE(review): train_buf appears to hold one entry per *batch* (the training
# log counts len(train_buf) batches per epoch), so it must not be divided by
# BATCH_SIZE again; doing so made the linear schedule reach a learning rate of
# zero after a small fraction of the first epoch. This assumes train_bern_model
# steps the scheduler once per batch; confirm against model_training.
total_steps = len(train_buf) * EPOCHS
warmup_steps = math.ceil(total_steps * 0.05)  # warm up over the first 5% of steps

scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=warmup_steps,
                                            num_training_steps=total_steps)

In [21]:
#torch.autograd.set_detect_anomaly(True)

In [22]:
# Display the number of entries in the training buffer.
len(train_buf)

50000

In [23]:
# Field names of a training entry, read from the first element of train_buf;
# passed to train_bern_model as batch_schema further down.
BATCH_SCHEMA = list(train_buf[0].keys())
BATCH_SCHEMA

['input_ids', 'decoder_input_ids', 'labels']

In [24]:
# Module-level switch on model_training, set before training starts.
# NOTE(review): presumably makes the trainer index batches by field name; confirm.
model_training.STRING_BATCH_INDEX = True

# Run training and evaluation; the returned statistics are kept in `stats`.
# NOTE(review): forward_args lists "attention_mask", which is not a key of
# train_buf[0] (see BATCH_SCHEMA); verify how train_bern_model treats it.
stats = train_bern_model(
    model,
    optimizer,
    scheduler,
    EPOCHS,
    device,
    nn.CrossEntropyLoss(),
    id2label,
    # data
    train_dataloader=train_buf,
    test_dataloader=test_buf,
    train_batch_size=BATCH_SIZE,
    test_batch_size=BATCH_SIZE,
    batch_schema=BATCH_SCHEMA,
    # model interface
    forward_args=["input_ids", "decoder_input_ids", "attention_mask", "labels"],
    generic_output_class=True,
    chomsky_task=True,
    # reporting (vocab_size=VOCAB_SIZE is intentionally not passed)
    calc_metrics=True,
    print_status=True,
)



Training...
  Batch 1,000  of  50,000.    Elapsed:  0:00:05, Remaining:  0:04:05.
  Batch 2,000  of  50,000.    Elapsed:  0:00:10, Remaining:  0:04:00.
  Batch 3,000  of  50,000.    Elapsed:  0:00:14, Remaining:  0:03:55.
  Batch 4,000  of  50,000.    Elapsed:  0:00:19, Remaining:  0:03:50.
  Batch 5,000  of  50,000.    Elapsed:  0:00:24, Remaining:  0:03:45.
  Batch 6,000  of  50,000.    Elapsed:  0:00:29, Remaining:  0:03:40.
  Batch 7,000  of  50,000.    Elapsed:  0:00:34, Remaining:  0:03:35.
  Batch 8,000  of  50,000.    Elapsed:  0:00:38, Remaining:  0:03:30.
  Batch 9,000  of  50,000.    Elapsed:  0:00:43, Remaining:  0:03:25.
  Batch 10,000  of  50,000.    Elapsed:  0:00:48, Remaining:  0:03:20.
  Batch 11,000  of  50,000.    Elapsed:  0:00:53, Remaining:  0:03:15.
  Batch 12,000  of  50,000.    Elapsed:  0:00:57, Remaining:  0:03:10.
  Batch 13,000  of  50,000.    Elapsed:  0:01:02, Remaining:  0:03:05.
  Batch 14,000  of  50,000.    Elapsed:  0:01:07, Remaining:  0:03:00.
 