In [1]:
%load_ext autoreload  
%autoreload 2 

In [49]:
import torchspider
from datasets import load_dataset
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
from torch.utils.data import DataLoader

In [None]:
"""

# select 1000 unlabeled data to label 

selected_unlabeld_data = select_unlabeled_data(all_unlabeled_data, strategy="random", num=1000)
selected_unlabeld_data = select_unlabeled_data(all_unlabeled_data, strategy="uncertainty_sampling", num=1000)
selected_unlabeld_data = select_unlabeled_data(all_unlabeled_data, strategy="optimal_subset", optimal_subset=optimal_subset)

# label these data (in this case, we just get their labels from the HF dataset, because they are already labeled and we were only pretending that they were unlabeled)

selected_labeled_data = label_data(selected_unlabeld_data, hf_data)

# train the model with these newly labeled data and plot them out 

"""

# Data


In [102]:
@dataclass(frozen=True)
class Config:
    max_length: int = 66
    debug: bool = False
    epochs: int = 10
    batch_size: int = 8
    model_name: str = "google/electra-small-discriminator"
    optimizer: str = "adamw"
    loss_func: str = "cross_entropy_loss"
    lr: int = 1e-5
    path: str = "."


config = Config(max_length=66, debug=True)

In [127]:
tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")

def preprocess(data):
    data = data.rename_column('label', 'scalar_label')
    data = data.map(lambda x: {'label' : 0 if x['scalar_label'] < 0.5 else 1})

    def tokenize_func(examples): 
        tokenized = tokenizer(
            examples["sentence"], padding="max_length", max_length=config.max_length, truncation=True
        )
        tokenized["label"] = examples["label"]
        return tokenized

    ds = data.map(
        tokenize_func,
        remove_columns=data.column_names,
        batched=True,
    )
    ds.set_format(type="torch")
    return ds

In [128]:
sst2 = load_dataset("sst")

selected_indices = np.random.choice(len(sst2["train"]), replace=False, size=1000)
selected_data = sst2["train"].select(selected_indices)
debug_data = sst2["train"].select(selected_indices[:8])

100%|██████████| 3/3 [00:00<00:00, 603.73it/s]


In [130]:
train_ds = preprocess(selected_data)
valid_ds = preprocess(sst2["validation"])
test_ds = preprocess(sst2["test"])
debug_ds = preprocess(debug_data)

100%|██████████| 1000/1000 [00:00<00:00, 4152.37ex/s]
100%|██████████| 1/1 [00:00<00:00,  3.18ba/s]
100%|██████████| 1101/1101 [00:00<00:00, 10050.82ex/s]
100%|██████████| 2/2 [00:00<00:00, 22.08ba/s]
100%|██████████| 2210/2210 [00:00<00:00, 11649.37ex/s]
100%|██████████| 3/3 [00:00<00:00,  9.61ba/s]
100%|██████████| 8/8 [00:00<00:00, 2607.59ex/s]
100%|██████████| 1/1 [00:00<00:00, 167.31ba/s]


In [132]:
# selected_data["label"]

In [135]:
# train_ds["label"]

In [136]:
train_ds

Dataset({
    features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 1000
})

In [137]:
list(map(len, [train_ds, valid_ds, test_ds, debug_ds]))

[1000, 1101, 2210, 8]

In [138]:
train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=config.batch_size, shuffle=False)
test_dl = DataLoader(test_ds, batch_size=config.batch_size, shuffle=False)
debug_dl = DataLoader(debug_ds, batch_size=config.batch_size, shuffle=True)

In [139]:
len(train_dl), len(valid_dl), len(test_dl)

(125, 138, 277)

In [140]:
for batch in train_dl:
    # print(len(batch["input_ids"]))
    # print(tokenizer.decode(batch["input_ids"][0]))
    print(batch.keys())
    break


dict_keys(['label', 'input_ids', 'token_type_ids', 'attention_mask'])


In [141]:
dls = DataLoaderGroup(train_dl, valid_dl, test_dl) if not config.debug else DataLoaderGroup(debug_dl, debug_dl, test_dl)

In [142]:
len(dls.train_dl), len(debug_dl), train_ds[:config.batch_size*2]["input_ids"].shape

(1, 1, torch.Size([16, 66]))

# Training


In [143]:
from torchspider import *
import torch

In [144]:
cbs = (
    [CudaCallback(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")), TrackLoss(), Debugger()]
    if config.debug
    else [
        CudaCallback(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")),
        WandbTrackAndSave("beautify", "beautify_bullet"),
    ]
)



In [145]:
model = AutoModelForSequenceClassification.from_pretrained(config.model_name)

Some weights of the model checkpoint at google/electra-small-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-small-discriminator and are newly initialized: ['classifier

In [146]:
learner = Learner(model, 
                  dls, 
                  config, 
                  cbs=cbs)

saved dls successfully!




In [147]:
learner.fit(config.epochs)

epoch 1:   0%|          | 0/1 [00:00<?, ?it/s]

AttributeError: 'dict' object has no attribute 'size'

In [62]:
len(dls.train_dl)

1

In [80]:
# for i, batch in enumerate(learner.dls.train_dl):
    # print(len(batch))
for i, batch in enumerate(dls.train_dl):
    print(batch)

{'input_ids': tensor([[  101,  1012,  1012,  1012,  2065,  2017,  1005,  2128,  2074,  1999,
          1996,  6888,  2005,  1037,  4569,  1011,  1011,  2021,  2919,  1011,
          1011,  3185,  1010,  2017,  2453,  2215,  2000,  4608, 29526,  2004,
          1037, 13523,  3170,  2063,  1012,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [  101,  1996,  2143,  4152,  2485,  2000,  1996,  9610, 25370,  1996,
          2168,  2126,  2204,  8095,  2106,  1010,  2007,  1037,  3809, 13128,
         11752,  1010,  4847,  1998, 12242,  1012,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0, 