In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import AdamW
import transformers
from transformers import get_linear_schedule_with_warmup
import math
import numpy as np

from load_set import *
import model_training
from model_training import train_bern_model
import coin_i3C_modeling as ci3C
from rnn_modeling import *
from ntm import *
from lnn_modeling import *
from pcoin_modeling import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
BATCH_SIZE = 50
MAX_POSITION_EMBEDDINGS = 20
TEST_POSITION_EMBEDDINGS = 200
VOCAB_SIZE = 11#MAX_POSITION_EMBEDDINGS
NUM_LABELS = 5
HIDDEN_SIZE = 256

NUM_TRAIN_SAMPLES = 2_500_000 #256 * 380
NUM_WARMUP_STEPS = 1_500_000
SAMPLE_METHOD = "static-warmup"

#NUM_TRAIN_SAMPLES = 2_500_000
#NUM_WARMUP_STEPS = 1_500_000
#SAMPLE_METHOD = "linspace"

#SAMPLE_METHOD = "uniform"

#NUM_TRAIN_SAMPLES = 50_000
#SAMPLE_METHOD = "static"

NUM_TEST_SAMPLES = 1000
LEARNING_RATE = 1e-4#1e-5
EPS = 1e-8
EPOCHS = 10

In [3]:
#MAX_POSITION_EMBEDDINGS = max(MAX_POSITION_EMBEDDINGS, TEST_POSITION_EMBEDDINGS) # hotfix

In [4]:
NUM_TRAIN_SAMPLES, NUM_TEST_SAMPLES

(2500000, 1000)

In [5]:
# bucket-sort, duplicate-string, parity-check, missing-duplicate-string, binary-addition, binary-sqrt, modular-arithmetic
#TASK = "duplicate-string"
#TASK = "missing-duplicate-string"
TASK = "parity-check"
#TASK = "bucket-sort"
#TASK = "binary-addition"
#TASK = "binary-sqrt"
#TASK = "modular-arithmetic"
#TASK = "modular-arithmetic-brackets"

In [6]:
if TASK in ("parity-check", "binary-sqrt"):
    VOCAB_SIZE = 2
    NUM_LABELS = 2
elif TASK in ("missing-duplicate-string", "binary-addition"):
    VOCAB_SIZE = 3
    NUM_LABELS = 2

In [7]:
class GenConfig:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

In [8]:
if 1:
    model = pCOINForSequenceClassification(
        config=pCOINConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            hidden_dropout_prob=0.1,
            hidden_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
            one_hot_encoding=True,
            chunk_schema=["1", "T / 2"]
        )
    )

In [9]:
if 0:
    model = NTMForParityCheck(
        config=RNNConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            hidden_dropout_prob=0.0,
            num_hidden_layers=1,
            hidden_size=HIDDEN_SIZE,
            intermediate_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
        )
    )

In [10]:
if 0:
    # hidden_size = units = backbone_units
    model = CfCForParityCheck(
        CfCConfig(
            input_size = 256,
            hidden_size = 256,
            sparsity_mask = None,
            backbone_layers = 1,
            backbone_units = 256,
            backbone_dropout = 0.0,
            mode = "default",
            units = 256,
            proj_size = 256,
            wiring = None,
            backbone_activation = "tanh",
            mixed_memory = False,
            vocab_size=VOCAB_SIZE,
            num_labels=NUM_LABELS,
        )
    )

In [11]:
if 0:
    model = sLSTMForParityCheck(
        config=RNNConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            hidden_dropout_prob=0.0,
            num_hidden_layers=1,
            hidden_size=HIDDEN_SIZE,
            intermediate_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
        )
    )

In [12]:
if 0:
    model = COINForBucketSort(
        config=RNNConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            hidden_dropout_prob=0.1,
            num_hidden_layers=1,
            hidden_size=HIDDEN_SIZE,
            intermediate_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
            layer_norm_eps=1e-12,
            carry_over_S=True,
        )
    )

In [13]:
if 0:
    l_size = 32
    model = CfCForBucketSort(
        config=CfCConfig(
            vocab_size=VOCAB_SIZE,
            #max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            hidden_dropout_prob=0.1,
            #num_hidden_layers=1,
            input_size=l_size,
            hidden_size=l_size,
            units=l_size,
            backbone_units=l_size,
            #intermediate_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
            #layer_norm_eps=1e-12
        )
    )

In [14]:
if 0:
    l_size = 32
    model = LTCForBucketSort(
        config=LTCConfig(
            vocab_size=VOCAB_SIZE,
            #max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            hidden_dropout_prob=0.1,
            #num_hidden_layers=1,
            input_size=l_size,
            hidden_size=l_size,
            units=l_size,
            #intermediate_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
            #layer_norm_eps=1e-12
        )
    )

In [15]:
if 0:
    model = COINForParityCheck(
        config=RNNConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            hidden_dropout_prob=0.1,
            num_hidden_layers=1,
            hidden_size=HIDDEN_SIZE,
            intermediate_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
            layer_norm_eps=1e-12,
            carry_over_S=False,
            chunk_schema=[ lambda T: 1]
        )
    )

In [16]:
if 0:
    model = RNNForParityCheck(
        config=RNNConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            hidden_dropout_prob=0.1,
            num_hidden_layers=1,
            hidden_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
        )
    )

In [17]:
if 0:
    model = StackRNNForParityCheck(
        config=RNNConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            hidden_dropout_prob=0.5,
            num_hidden_layers=1,
            hidden_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
        )
    )

In [18]:
if 0:
    # B=50
    # T_train=20
    # T_test= >500
    class LSTMForParityCheck(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.config = config
            self.embeddings = nn.Linear(config.vocab_size, config.hidden_size)
            self.n_layers = config.num_hidden_layers
            #self.lstm = nn.LSTM(2, config.hidden_size, batch_first=True, num_layers=self.n_layers, dropout=config.hidden_dropout_prob)
            self.lstm = nn.LSTM(config.hidden_size, config.hidden_size, batch_first=True, num_layers=self.n_layers, dropout=config.hidden_dropout_prob)
            self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
            self.classifier = nn.Sequential(
                #nn.Dropout(config.hidden_dropout_prob, 
                nn.Linear(config.hidden_size, config.num_labels)
            )
            self.loss_fn = nn.CrossEntropyLoss()

        def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
            logits = F.one_hot(input_ids.long(), self.config.vocab_size).float()
            logits = self.embeddings(logits)
            B, T, C = logits.shape
            hidden = (torch.randn(self.n_layers, B, C, device=logits.device), torch.randn(self.n_layers, B, C, device=logits.device))
            logits, hidden = self.lstm(logits)
            r_logits = self.classifier(logits[:, -1])
            loss = self.loss_fn(r_logits, labels)
            return ci3C.COINOutputClass(
                logits=r_logits,
                loss=loss
            )

    model = LSTMForParityCheck(
        config=RNNConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            hidden_dropout_prob=0.5,
            num_hidden_layers=1,
            hidden_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
        )
    )

In [19]:
if 0:
    class BertForBucketSort(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.config = config
            self.bert = transformers.BertModel(config)
            self.embeddings = nn.Linear(config.vocab_size, config.hidden_size)
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.loss_fn = nn.CrossEntropyLoss()

        def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
            emb = self.embeddings(F.one_hot(input_ids.long(), self.config.vocab_size).float())
            #print(input_ids.shape, emb.shape)
            #emb = F.one_hot(input_ids.long(), self.config.vocab_size).float()
            B, T, C = emb.shape
            logits = self.bert(inputs_embeds=emb, attention_mask=attention_mask).last_hidden_state
            
            #logits = logits[:, -math.ceil(T / 2):]
            
            logits = self.lm_head(logits)
            labels = labels.long()
            
            loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
            return ci3C.COINOutputClass(
                logits=logits,
                loss=loss
            )


    model = BertForBucketSort(
        config=transformers.BertConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=max(MAX_POSITION_EMBEDDINGS, TEST_POSITION_EMBEDDINGS)+1,
            num_hidden_layers=2,
            hidden_size=HIDDEN_SIZE,
            num_attention_heads=1,
        )
    )

In [20]:
if 0:
    class BertForModularArithmetic(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.config = config
            self.bert = transformers.BertForSequenceClassification(config)
            self.embeddings = nn.Linear(config.vocab_size, config.hidden_size)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
            self.loss_fn = nn.CrossEntropyLoss()

        def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
            emb = F.one_hot(input_ids.long(), self.config.vocab_size).float()
            emb = self.embeddings(emb)
            logits = self.bert(inputs_embeds=emb, attention_mask=attention_mask).logits
            #logits = self.bert(input_ids=input_ids, attention_mask=attention_mask).logits
            #logits = self.classifier(logits[:, 0, :])
            
            loss = self.loss_fn(logits, labels)
            return ci3C.COINOutputClass(
                logits=logits,
                loss=loss
            )

    model = BertForModularArithmetic(
        config=transformers.BertConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            num_hidden_layers=1,
            hidden_size=HIDDEN_SIZE,
            num_attention_heads=1,
            num_labels=5,
        )
    )

In [21]:
if 0:
    class BertForParityCheck(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.config = config
            #self.bert = transformers.BertModel(config)
            self.bert = transformers.BertForSequenceClassification(config)
            self.embeddings = nn.Linear(config.vocab_size, config.hidden_size)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
            self.loss_fn = nn.CrossEntropyLoss()

        def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
            emb = F.one_hot(input_ids.long(), self.config.vocab_size).float()
            emb = self.embeddings(emb)
            logits = self.bert(inputs_embeds=emb, attention_mask=attention_mask).logits
            #logits = self.bert(input_ids=input_ids, attention_mask=attention_mask).logits
            #logits = self.classifier(logits[:, 0, :])
            
            loss = self.loss_fn(logits, labels)
            return ci3C.COINOutputClass(
                logits=logits,
                loss=loss
            )


    model = BertForParityCheck(
        config=transformers.BertConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            num_hidden_layers=1,
            hidden_size=HIDDEN_SIZE,
            num_attention_heads=1,
            num_labels=NUM_LABELS,
        )
    )

In [22]:
if 0:
    class LSTMForSequenceClassification(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.config = config
            self.lstm = nn.LSTM(
                config.hidden_size, 
                config.hidden_size, 
                config.num_hidden_layers, 
                batch_first=True,
                dropout=config.hidden_dropout_prob
            )
            self.embeddings = nn.Linear(config.vocab_size, config.hidden_size)
            self.classifier = nn.Sequential(
                #nn.Linear(config.hidden_size, config.hidden_size),
                #nn.Tanh(),
                nn.Linear(config.hidden_size, config.num_labels)
            )
            self.loss_fn = nn.CrossEntropyLoss()

        def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
            emb = F.one_hot(input_ids.long(), self.config.vocab_size).float()
            emb = self.embeddings(emb)
            logits, (h_n, c_n) = self.lstm(emb)
            h_n = h_n.transpose(0, 1)
            c_n = c_n.transpose(0, 1)
            logits = self.classifier(logits[:, 0, :])
            #logits = self.classifier(h_n)
            loss = self.loss_fn(logits, labels)
            return Output(
                logits=logits,
                loss=loss
            )

    model = LSTMForSequenceClassification(
        config=RNNConfig(
            vocab_size=VOCAB_SIZE,
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            num_hidden_layers=1,
            hidden_size=HIDDEN_SIZE,
            num_labels=NUM_LABELS,
        )
    )

In [23]:
accuracy_mask = None
if TASK == "bucket-sort":
    gen_fn = generate_bucket_sort_set
elif TASK == "duplicate-string":
    gen_fn = generate_duplicate_string_set
elif TASK == "parity-check":
    gen_fn = generate_parity_check_set
elif TASK == "missing-duplicate-string":
    gen_fn = generate_missing_duplicate_string_set
elif TASK == "binary-addition":
    gen_fn = generate_binary_addition_set
    accuracy_mask = binary_addition_mask
elif TASK == "binary-sqrt":
    gen_fn = generate_binary_sqrt_set
    accuracy_mask = binary_sqrt_mask
elif TASK == "modular-arithmetic":
    gen_fn = generate_modular_arithmetic_set
    #gen_fn = generate_str_modular_arithmetic_set
elif TASK == "modular-arithmetic-brackets":
    gen_fn = generate_modular_arithmetic_brackets_set

#train_buf = generate_uniform_batches(gen_fn, B=BATCH_SIZE, T=MAX_POSITION_EMBEDDINGS, vocab_size=VOCAB_SIZE, num_samples=NUM_TRAIN_SAMPLES)
#test_buf = generate_uniform_batches(gen_fn, B=BATCH_SIZE, T=MAX_POSITION_EMBEDDINGS, vocab_size=VOCAB_SIZE, num_samples=NUM_TEST_SAMPLES)
train_buf, test_buf = generate_ch_batches(
    generate_set_fn=gen_fn, 
    B=BATCH_SIZE, 
    T_train=MAX_POSITION_EMBEDDINGS, 
    T_test=TEST_POSITION_EMBEDDINGS, 
    vocab_size=VOCAB_SIZE, 
    num_train_samples=NUM_TRAIN_SAMPLES, 
    num_test_samples=NUM_TEST_SAMPLES,
    sample_method=SAMPLE_METHOD,
    num_warmup_steps=NUM_WARMUP_STEPS,
)

sample schema: tensor([ 1,  1,  1,  ..., 20, 20, 20])


In [24]:
#train_buf, test_buf = generate_static_parity_check_set(BATCH_SIZE, MAX_POSITION_EMBEDDINGS, TEST_POSITION_EMBEDDINGS, NUM_TRAIN_SAMPLES, NUM_TEST_SAMPLES)

In [25]:
#BATCH_SCHEMA = ["input_ids", "decoder_input_ids", "attention_mask", "labels"]
#BATCH_SCHEMA = ["input_ids", "labels"]

In [26]:
labels = np.unique(train_buf[0]["labels"]).tolist()
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
print(label2id)
print(id2label)
print(labels)

{0: 0, 1: 1}
{0: 0, 1: 1}
[0, 1]


In [27]:
device = torch.device("cuda")# if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=EPS)

total_steps = len(train_buf) / BATCH_SIZE * EPOCHS
#warmup_steps = math.ceil(total_steps * 0.05)
warmup_steps = math.ceil(total_steps * 0.1)

scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=warmup_steps,
                                            num_training_steps=total_steps)

In [28]:
#torch.autograd.set_detect_anomaly(True)

In [29]:
len(train_buf)

50000

In [30]:
BATCH_SCHEMA = list(train_buf[0].keys())
BATCH_SCHEMA

['input_ids', 'decoder_input_ids', 'attention_mask', 'labels']

In [31]:
torch.set_printoptions(threshold=100_000_000)
train_buf[-1]

{'input_ids': tensor([[1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0],
         [0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1],
         [0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0],
         [1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1],
         [0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1],
         [0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1],
         [0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1],
         [1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0],
         [1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1],
         [1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0],
         [0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 

In [32]:
model_training.STRING_BATCH_INDEX = True
stats = train_bern_model(
    model,
    optimizer,
    scheduler,
    EPOCHS,
    device,
    nn.CrossEntropyLoss(),
    id2label,
    batch_schema=BATCH_SCHEMA,
    train_dataloader=train_buf,
    test_dataloader=test_buf,
    #vocab_size=VOCAB_SIZE,
    print_status=True,
    train_batch_size=BATCH_SIZE,
    test_batch_size=BATCH_SIZE,
    generic_output_class=True,
    forward_args=["input_ids", "decoder_input_ids", "attention_mask", "labels"],
    chomsky_task=True,
    accuracy_mask=accuracy_mask,

    calc_metrics=True
)



Training...
  Batch 1,000  of  50,000.    Elapsed:  0:00:02, Remaining:  0:01:38.
  Batch 2,000  of  50,000.    Elapsed:  0:00:05, Remaining:  0:01:36.
  Batch 3,000  of  50,000.    Elapsed:  0:00:07, Remaining:  0:02:21.
  Batch 4,000  of  50,000.    Elapsed:  0:00:11, Remaining:  0:02:18.
  Batch 5,000  of  50,000.    Elapsed:  0:00:14, Remaining:  0:02:15.
  Batch 6,000  of  50,000.    Elapsed:  0:00:18, Remaining:  0:02:56.
  Batch 7,000  of  50,000.    Elapsed:  0:00:22, Remaining:  0:02:52.
  Batch 8,000  of  50,000.    Elapsed:  0:00:26, Remaining:  0:02:48.
  Batch 9,000  of  50,000.    Elapsed:  0:00:31, Remaining:  0:03:25.
  Batch 10,000  of  50,000.    Elapsed:  0:00:36, Remaining:  0:03:20.
  Batch 11,000  of  50,000.    Elapsed:  0:00:41, Remaining:  0:03:15.
  Batch 12,000  of  50,000.    Elapsed:  0:00:46, Remaining:  0:03:48.
  Batch 13,000  of  50,000.    Elapsed:  0:00:52, Remaining:  0:03:42.
  Batch 14,000  of  50,000.    Elapsed:  0:00:58, Remaining:  0:03:36.
 

In [None]:
# .88 / .78 for lr=1e-4, B=15, T=(20, 20), 3_500_000 2_500_000 static-warmup