In [None]:
# Show the lines of `pip freeze` output that contain 'torch'.
!pip freeze | grep 'torch' -E
# Install the `datasets` package (imported below as `hfdata`).
!pip install datasets

torch @ https://download.pytorch.org/whl/cu121/torch-2.3.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=0a12aa9aa6bc442dff8823ac8b48d991fd0771562eaa38593f9c8196d65f7007
torchaudio @ https://download.pytorch.org/whl/cu121/torchaudio-2.3.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=38b49393f8c322dcaa29d19e5acbf5a0b1978cf1b719445ab670f1fb486e3aa6
torchsummary==1.5.1
torchtext==0.18.0
torchvision @ https://download.pytorch.org/whl/cu121/torchvision-0.18.0%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=13e1b48dc5ce41ccb8100ab3dd26fdf31d8f1e904ecf2865ac524493013d0df5


In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertModel
import datasets as hfdata
import pathlib as Path
import re
import time
from tqdm import tqdm

## Prep Data - `AGNEWS`

In [None]:
# Load the 'ag_news' dataset via Hugging Face `datasets`;
# the last line displays its train and test splits.
agnews_hf = hfdata.load_dataset('ag_news')
agnews_hf['train'], agnews_hf['test']

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


(Dataset({
     features: ['text', 'label'],
     num_rows: 120000
 }),
 Dataset({
     features: ['text', 'label'],
     num_rows: 7600
 }))

In [None]:
# Peek at the first five raw training texts.
agnews_hf['train']['text'][:5]

["Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.",
 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.',
 "Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market next week during the depth of the\\summer doldrums.",
 'Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\\flows from the main pipeline in southern Iraq after\\intelligence showed a rebel militia could strike\\infrastructure, an oil official said on Saturday.',
 'Oil prices soar to all-time record, posing new menace to US economy (A

In [None]:
# Second training text, kept as a small sample for trying out the prefix-stripping helper.
s = agnews_hf['train']['text'][1]
def get_description_without_title(strings):
    """
    Strip the "<title> (<source>) <source> - " prefix from each news text.

    Each input is searched on its own for the first "(<source>) <source> - "
    marker (sources made of ASCII letters only); everything after the marker
    is kept. Texts without such a marker are kept whole, so the result always
    has exactly one entry per input and stays aligned with the label list.

    Line breaks appear in the raw texts as a backslash character, so every
    backslash is replaced with a single space.

    Args:
        strings: iterable of raw news texts.

    Returns:
        list[str]: one cleaned description per input, in input order.
    """
    # Raw string: the escaped parentheses are regex escapes, not Python ones.
    pattern = re.compile(r'\([A-Za-z]+\) [A-Za-z]+ - ')
    descriptions = []
    for text in strings:
        match = pattern.search(text)
        # Fall back to the full text when there is no source marker.
        body = text[match.end():] if match else text
        descriptions.append(body.replace('\\', ' '))
    return descriptions
# Sanity-check the helper on the single sample text.
get_description_without_title([s])

['Private investment firm Carlyle Group, which has a reputation for making well-timed and occasionally controversial plays in the defense industry, has quietly placed its bets on another part of the market.']

In [None]:
# Build the cleaned text lists and the label lists for both splits.
train_split, test_split = agnews_hf['train'], agnews_hf['test']
train_texts = get_description_without_title(train_split['text'])
test_texts = get_description_without_title(test_split['text'])
train_labels, test_labels = train_split['label'], test_split['label']

# No cleaned text may be empty.
assert all(train_texts)
assert all(test_texts)

In [None]:
# Longest text per split, counted in space-separated pieces: (train, test).
max(len(_.split(' ')) for _ in train_texts), max(len(_.split(' ')) for _ in test_texts)

(185, 141)

In [None]:
class MyDataset(Dataset):
    """Map-style dataset that pairs each input with its label.

    Args:
        x: indexable collection of inputs (this notebook passes both a list
            of strings and a tensor of embeddings).
        y: indexable collection of labels with the same length as ``x``
            (this notebook passes both a list of ints and a tensor).

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths, which would
            otherwise surface later as an IndexError or silently ignored items.
    """

    def __init__(self, x, y):
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __len__(self):
        """Number of (input, label) pairs."""
        return len(self.y)

    def __getitem__(self, idx):
        """Get (text, label) tuple."""
        return self.x[idx], self.y[idx]

In [None]:
# Wrap the cleaned texts and their labels as (text, label) datasets.
train_dataset = MyDataset(train_texts, train_labels)
test_dataset = MyDataset(test_texts, test_labels)

# get loaders;
def get_loaders(train_dataset, test_dataset, batch_size, num_workers=1):
    """Build the (train, test) DataLoader pair.

    The train loader shuffles and drops the last incomplete batch; the test
    loader keeps dataset order and every sample. Both share ``batch_size``
    and ``num_workers``.
    """
    shared = dict(batch_size=batch_size, num_workers=num_workers)
    return (
        DataLoader(train_dataset, shuffle=True, drop_last=True, **shared),
        DataLoader(test_dataset, shuffle=False, drop_last=False, **shared),
    )

## Setup tokenizer and model;

In [None]:
# Load the pretrained fast tokenizer for 'google-bert/bert-base-uncased'.
tokenizer = BertTokenizerFast.from_pretrained('google-bert/bert-base-uncased')

In [None]:
# The model has 512 position embeddings; this demo truncates to 256 tokens
# (the training/inference code further down uses max_length=200 instead).
# Token id 101 is [CLS], 102 is [SEP] and 0 is [PAD].
sentences = ['hello there', 'Hello there', 'Hi Kenobi']  # check if hello and Hello get the same tokens;
tokenizer(sentences, padding='longest', return_tensors='pt', truncation='longest_first', max_length=256)

{'input_ids': tensor([[  101,  7592,  2045,   102,     0,     0],
        [  101,  7592,  2045,   102,     0,     0],
        [  101,  7632,  6358, 16429,  2072,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1]])}

### Load model;

In [None]:
# Load the pretrained 'google-bert/bert-base-uncased' encoder and display its architecture.
bert = BertModel.from_pretrained('google-bert/bert-base-uncased')
bert

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [None]:
# Freeze every BERT parameter, then verify that none still requires gradients.
bert.requires_grad_(False)
assert not any(p.requires_grad for p in bert.parameters())

### Setup model architecture;

In [None]:
class MLP(nn.Module):
    """Classifier head: in_dim -> 256 -> 128 -> 64 -> out_dim.

    Every Linear layer is preceded by a LayerNorm over its input features,
    and every hidden Linear is followed by a ReLU. The final Linear returns
    raw logits.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        widths = (in_dim, 256, 128, 64)
        layers = []
        # Hidden blocks: LayerNorm -> Linear -> ReLU.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.LayerNorm(fan_in), nn.Linear(fan_in, fan_out), nn.ReLU()])
        # Output block: LayerNorm -> Linear, no activation.
        layers.extend([nn.LayerNorm(widths[-1]), nn.Linear(widths[-1], out_dim)])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits for a batch of input features."""
        return self.net(x)

In [None]:
def model_forward(x, tokenizer, bert, mlp, device, x_is_cls: bool):
    """Run the classifier head on raw inputs or on precomputed embeddings.

    Args:
        x: when ``x_is_cls`` is true, an object with ``.to(device)`` that is
            fed to ``mlp`` directly; otherwise the input handed to ``tokenizer``.
        tokenizer: callable producing the BERT inputs (unused when ``x_is_cls``).
        bert: encoder whose ``'pooler_output'`` is fed to ``mlp``
            (unused when ``x_is_cls``).
        mlp: classifier head.
        device: device the inputs are moved to.
        x_is_cls: whether ``x`` already holds the embeddings.

    Returns:
        Whatever ``mlp`` returns for the embeddings.
    """
    if x_is_cls:
        # Embeddings were computed beforehand; only move them to the device.
        return mlp(x.to(device))

    encoded = tokenizer(
        x,
        padding='longest',
        return_tensors='pt',
        truncation='longest_first',
        max_length=200,
    ).to(device)
    return mlp(bert(**encoded)['pooler_output'])

In [None]:
def train_one_epoch(train_loader, tokenizer, bert, mlp, device, optim, x_is_cls=False):
    """Train ``mlp`` for one pass over ``train_loader``.

    ``bert`` is put in eval mode and ``mlp`` in train mode. Each batch is run
    through ``model_forward``, scored with cross-entropy, and followed by one
    optimizer step.

    Args:
        train_loader: iterable of ``(x, y)`` batches; it does not need to
            support ``len()``.
        tokenizer, bert, mlp, device, x_is_cls: forwarded to ``model_forward``.
        optim: optimizer; ``zero_grad()`` and ``step()`` are called per batch.

    Returns:
        The mean per-sample loss over the epoch (each batch weighted by its
        size), or 0 when the loader yields no batches.
    """
    bert.eval()  # bert itself is not trained;
    mlp.train()
    avg_instance_loss = 0
    n = 0
    for x, y in train_loader:
        y = y.to(device)
        optim.zero_grad()
        out = model_forward(
            x, tokenizer, bert, mlp, device, x_is_cls=x_is_cls
        )
        loss = nn.functional.cross_entropy(out, target=y)
        loss.backward()
        optim.step()

        # Running mean over samples: batches contribute in proportion to
        # their size, so a smaller batch does not skew the average.
        n += len(y)
        avg_instance_loss = avg_instance_loss + (loss.item() - avg_instance_loss) * len(y) / n
    return avg_instance_loss

In [None]:
def val_one_epoch(val_loader, tokenizer, bert, mlp, device, x_is_cls=False):
    """Evaluate ``mlp`` over ``val_loader`` without tracking gradients.

    Both ``bert`` and ``mlp`` are put in eval mode. Each batch is run through
    ``model_forward`` and scored with cross-entropy and argmax accuracy.

    Args:
        val_loader: iterable of ``(x, y)`` batches; it does not need to
            support ``len()``.
        tokenizer, bert, mlp, device, x_is_cls: forwarded to ``model_forward``.

    Returns:
        ``(avg_loss, avg_acc)``: the mean per-sample loss and accuracy (each
        batch weighted by its size), or ``(0, 0)`` when the loader yields
        no batches.
    """
    bert.eval()  # bert itself is not trained;
    mlp.eval()
    avg_instance_loss = 0
    avg_acc = 0
    n = 0
    with torch.no_grad():
        for x, y in val_loader:
            y = y.to(device)
            out = model_forward(
                x, tokenizer, bert, mlp, device, x_is_cls=x_is_cls
            )

            # Running means over samples; the last batch may be smaller, so
            # each batch is weighted by its size.
            n += len(y)

            # loss calculation;
            loss = nn.functional.cross_entropy(out, target=y)
            avg_instance_loss = avg_instance_loss + (loss.item() - avg_instance_loss) * len(y) / n

            # accuracy calculation
            acc = (out.argmax(-1) == y).float().mean().item()
            avg_acc = avg_acc + (acc - avg_acc) * len(y) / n
    return avg_instance_loss, avg_acc

In [None]:
def train_for_k_epochs(num_epochs, bert, mlp, tokenizer, optim, train_loader, val_loader, epoch_offset=0, x_is_cls=False, device=None):
    """Alternate training and validation epochs and collect the metrics.

    Args:
        num_epochs: number of epochs to run.
        bert, mlp, tokenizer, optim: forwarded to the per-epoch functions.
        train_loader: batches for ``train_one_epoch``.
        val_loader: batches for ``val_one_epoch``.
        epoch_offset: first epoch number used in the printed log.
        x_is_cls: forwarded to the per-epoch functions.
        device: device the batches are moved to. When omitted, the device of
            the first ``mlp`` parameter is used (CPU if ``mlp`` has none), so
            the function no longer depends on a notebook-global ``device``.

    Returns:
        Three lists with one entry per epoch: train losses, validation
        losses, validation accuracies.
    """
    if device is None:
        first_param = next(iter(mlp.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    instance_train_epoch_losses = []
    instance_val_epoch_losses = []
    instance_acc = []
    for t in range(epoch_offset, num_epochs + epoch_offset):
        print(f"Epoch {t}:\n{'-'*20}")
        train_loss = train_one_epoch(
            train_loader, tokenizer, bert, mlp, device, optim, x_is_cls=x_is_cls
        )

        val_loss, val_acc = val_one_epoch(
            val_loader, tokenizer, bert, mlp, device, x_is_cls=x_is_cls
        )
        print(f"{train_loss=:.5f}")
        print(f"{val_acc=:.3f}\t{val_loss=:.5f}")
        instance_train_epoch_losses.append(train_loss)
        instance_val_epoch_losses.append(val_loss)
        instance_acc.append(val_acc)
    return instance_train_epoch_losses, instance_val_epoch_losses, instance_acc

### Start Training;

In [None]:
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Seed torch's RNG before the MLP is constructed.
torch.manual_seed(0)
# setup predictors;
# in_dim=768 matches the 768-wide BERT layers displayed above.
mlp = MLP(
    in_dim=768,
    out_dim=4,
).to(device)
# Only the MLP's parameters are handed to the optimizer.
optim = torch.optim.Adam(mlp.parameters(), lr=1e-3)
bert = bert.to(device)

### Since BERT is frozen, precompute its pooled [CLS] embeddings once;

In [None]:
def get_cls_embeds(loader, tokenizer, bert, device):
    """Collect BERT's ``'pooler_output'`` and the labels for every batch.

    ``bert`` is put in eval mode and must have no parameter that requires
    gradients. Progress is printed roughly every tenth of the loader.

    Args:
        loader: sized iterable of ``(x, y)`` batches; ``x`` is passed to
            ``tokenizer`` and ``y`` is concatenated as-is.
        tokenizer: callable producing the inputs for ``bert``.
        bert: frozen encoder returning a mapping with ``'pooler_output'``.
        device: device the tokenized batch is moved to.

    Returns:
        ``(embeddings, labels)``, each concatenated along dim 0 in loader
        order; the embeddings are moved to the CPU.
    """
    bert.eval()
    assert not any(p.requires_grad for p in bert.parameters())

    n_batches = len(loader)
    report_every = max(1, n_batches // 10)
    pooled_chunks, label_chunks = [], []
    with torch.no_grad():
        for step, (texts, labels) in enumerate(loader, start=1):
            encoded = tokenizer(
                texts,
                padding='longest',
                return_tensors='pt',
                truncation='longest_first',
                max_length=200,
            ).to(device)
            pooled = bert(**encoded)['pooler_output']
            # Move each chunk to the CPU as it is produced.
            pooled_chunks.append(pooled.to('cpu'))
            label_chunks.append(labels)
            if step % report_every == 0:
                print(f"Progress: {step / n_batches * 100:.2f}% done")
    return torch.cat(pooled_chunks, 0), torch.cat(label_chunks, 0)

In [None]:
# setup data loaders;
BATCH_SIZE = 64
# `get_loaders` is not used for this pass: its train loader shuffles and drops
# the last incomplete batch, whereas here every sample is embedded once, in
# dataset order.
train_loader = DataLoader(
    train_dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=1
)
test_loader = DataLoader(
    test_dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=1
)
# Run the frozen BERT once over each split and keep the embeddings and labels.
train_x, train_y = get_cls_embeds(train_loader, tokenizer, bert, device)
print('\n')
test_x, test_y = get_cls_embeds(test_loader, tokenizer, bert, device)

Progress: 9.97% done
Progress: 19.95% done
Progress: 29.92% done
Progress: 39.89% done
Progress: 49.87% done
Progress: 59.84% done
Progress: 69.81% done
Progress: 79.79% done
Progress: 89.76% done
Progress: 99.73% done


Progress: 9.24% done
Progress: 18.49% done
Progress: 27.73% done
Progress: 36.97% done
Progress: 46.22% done
Progress: 55.46% done
Progress: 64.71% done
Progress: 73.95% done
Progress: 83.19% done
Progress: 92.44% done


In [None]:
# Check the shapes of the extracted embeddings and labels.
train_x.shape, train_y.shape, test_x.shape, test_y.shape

(torch.Size([120000, 768]),
 torch.Size([120000]),
 torch.Size([7600, 768]),
 torch.Size([7600]))

In [None]:
# Check which device each extracted tensor lives on.
train_x.device, train_y.device, test_x.device, test_y.device

(device(type='cpu'),
 device(type='cpu'),
 device(type='cpu'),
 device(type='cpu'))

In [None]:
# Wrap the precomputed embeddings and labels so the MLP can be trained on them
# directly, and rebind `train_loader` / `test_loader` to loaders over these
# embedding datasets (batch size BATCH_SIZE * 2).
train_dataset2 = MyDataset(train_x, train_y)
test_dataset2 = MyDataset(test_x, test_y)
train_loader, test_loader = get_loaders(
    train_dataset2, test_dataset2,
    batch_size=BATCH_SIZE * 2, num_workers=1
)

In [None]:
# Train the MLP on the precomputed embeddings. `bert` and `tokenizer` are still
# passed along, but with x_is_cls=True the forward pass feeds the loader's
# tensors straight to the MLP.
NUM_EPOCHS = 50
instance_train_losses, instance_val_losses, instance_val_accs = train_for_k_epochs(
    NUM_EPOCHS,
    bert,
    mlp,
    tokenizer,
    optim,
    train_loader,
    test_loader,
    epoch_offset=0,
    x_is_cls=True  # the input is the bert cls token;
)

Epoch 0:
--------------------
train_loss=0.60311
val_acc=0.827	val_loss=0.48858
Epoch 1:
--------------------
train_loss=0.46369
val_acc=0.781	val_loss=0.56927
Epoch 2:
--------------------
train_loss=0.43373
val_acc=0.819	val_loss=0.50623
Epoch 3:
--------------------
train_loss=0.41952
val_acc=0.842	val_loss=0.44147
Epoch 4:
--------------------
train_loss=0.40912
val_acc=0.859	val_loss=0.39412
Epoch 5:
--------------------
train_loss=0.40158
val_acc=0.858	val_loss=0.39736
Epoch 6:
--------------------
train_loss=0.39514
val_acc=0.855	val_loss=0.40200
Epoch 7:
--------------------
train_loss=0.38705
val_acc=0.853	val_loss=0.40167
Epoch 8:
--------------------
train_loss=0.37761
val_acc=0.857	val_loss=0.39522
Epoch 9:
--------------------
train_loss=0.37211
val_acc=0.852	val_loss=0.40112
Epoch 10:
--------------------
train_loss=0.37070
val_acc=0.869	val_loss=0.36864
Epoch 11:
--------------------
train_loss=0.36484
val_acc=0.853	val_loss=0.40091
Epoch 12:
--------------------
train_l

### Sanity-check predictions and save model;

In [None]:
# Class index -> human-readable category name (index is the position in the tuple).
decode = dict(enumerate(('World', 'Sports', 'Business', 'Sci/Tech')))

In [None]:
# Classify three hand-written sentences end to end: with x_is_cls=False the
# strings are tokenized and passed through BERT before reaching the MLP.
# `argmax(-1)` picks the highest-scoring class index per sentence.
outs = model_forward(['Cristiano Ronaldo scored a goal!',
                      'Investors are bearish on the french bond market.',
                      'The new transformer model beats the performance of Recurrent Neural Networks'],
                     tokenizer, bert, mlp, device, x_is_cls=False).argmax(-1)
outs

tensor([1, 2, 3], device='cuda:0')

In [None]:
# Translate the predicted class indices into category names.
predicted_ids = outs.to('cpu').tolist()
[decode[class_id] for class_id in predicted_ids]

['Sports', 'Business', 'Sci/Tech']

In [None]:
# Rebinds `Path` to the class (the imports cell bound that name to the pathlib module).
from pathlib import Path

# Checkpoints go into ./checkpoints under the current working directory;
# the directory is created if it does not exist yet.
p = Path.cwd() / 'checkpoints'
p.mkdir(parents=True, exist_ok=True)

In [None]:
# Save the trained MLP head together with its optimizer state and the settings
# of the run. BERT's weights are not stored: it was kept frozen.
torch.save(
    {
        # Taken from the constant that drove the training call rather than a
        # repeated literal, so the two cannot drift apart.
        'epochs_done': NUM_EPOCHS,
        'bert_was_trained': False,
        'batch_size': BATCH_SIZE * 2,
        'mlp_state_dict': mlp.state_dict(),
        'mlp_optim_state_dict': optim.state_dict()
    },
    p / 'mlp_agnews.ckpt')

In [None]:
# List the contents of the Hugging Face cache directory.
! ls -a ~/.cache/huggingface/

.  ..  datasets  hub


### Done;