In [1]:
pip install datasets

Collecting datasets
  Downloading datasets-2.16.1-py3-none-any.whl (507 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dill, multiprocess, datasets
Successfully installed datasets-2.16.1 dill-0.3.7 multiprocess-0.70.15


In [2]:
import torch
from transformers import BertForSequenceClassification, BertTokenizer, BertModel, BertConfig
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.nn.utils.rnn import pad_sequence

In [3]:
tokenizer = AutoTokenizer.from_pretrained('google/bert_uncased_L-8_H-512_A-8')
dataset = load_dataset('glue', 'sst2')

num_prompt_tokens =10

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/383 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/31.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [4]:
class ModelWithSoftPrompts(torch.nn.Module):
    def __init__(self, model_name, num_prompt_tokens):
        super().__init__()
        self.num_prompt_tokens = num_prompt_tokens
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        for param in self.model.parameters():
            param.requires_grad = False
        self.prompt_embeddings = torch.nn.Embedding(num_prompt_tokens, self.model.config.hidden_size)
        for param in self.prompt_embeddings.parameters():
            param.requires_grad = True

        torch.nn.init.normal_(self.prompt_embeddings.weight, std=0.02)  # Initialize prompts

    def forward(self, input_ids, attention_mask):
        prompt_embeddings = self.prompt_embeddings(torch.arange(self.num_prompt_tokens, device=input_ids.device)).unsqueeze(0)
        prompt_embeddings = prompt_embeddings.repeat(input_ids.shape[0], 1, 1)  # Repeat for batch size

        token_embeddings = self.model.bert.embeddings.word_embeddings(input_ids)

        full_embeddings = torch.cat((prompt_embeddings, token_embeddings), dim=1)

        prompt_attention_mask = torch.ones((attention_mask.shape[0], self.num_prompt_tokens), device=attention_mask.device)
        full_attention_mask = torch.cat((prompt_attention_mask, attention_mask), dim=1)

        outputs = self.model(inputs_embeds=full_embeddings, attention_mask=full_attention_mask)
        return outputs

In [5]:
model = ModelWithSoftPrompts('google/bert_uncased_L-8_H-512_A-8', num_prompt_tokens)

pytorch_model.bin:   0%|          | 0.00/167M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-8_H-512_A-8 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Tokenizer helper to add prompt token IDs
def add_prompt_tokens_to_input_ids(input_ids, num_prompt_tokens, pad_token_id):
    prompt_tokens = torch.full((num_prompt_tokens,), pad_token_id, dtype=torch.long)
    return torch.cat([prompt_tokens, input_ids], dim=0)

# preprocess
def preprocess(example):
    encoded = tokenizer.encode_plus(
        example['sentence'],
        add_special_tokens=True,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )
    input_ids = encoded['input_ids'][0]
    attention_mask = encoded['attention_mask'][0]

    input_ids = add_prompt_tokens_to_input_ids(input_ids, num_prompt_tokens, tokenizer.pad_token_id)

    attention_mask = torch.cat([torch.ones(num_prompt_tokens), attention_mask], dim=0)

    label = torch.tensor(example['label'], dtype=torch.long)

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': label
    }

dataset = dataset.map(preprocess, batched=False)

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [7]:
def collate_fn(batch):
    input_ids = pad_sequence([torch.tensor(item['input_ids'], dtype=torch.long) for item in batch],
                             batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_mask = pad_sequence([torch.tensor(item['attention_mask'], dtype=torch.long) for item in batch],
                                  batch_first=True, padding_value=0)
    labels = torch.tensor([item['labels'] for item in batch], dtype=torch.long)
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels
    }

train_dataset = dataset['train']
val_dataset = dataset['validation']
test_dataset = dataset['test']

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, collate_fn=collate_fn)



In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

optimizer = torch.optim.AdamW(model.prompt_embeddings.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

epochs = 5
total_steps = len(train_loader) * epochs

In [9]:
def count_trainable_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def count_all_params(model):
    return sum(p.numel() for p in model.parameters())

print(f"Total parameters: {count_all_params(model)}")
print(f"Trainable parameters: {count_trainable_params(model)}")

Total parameters: 41379330
Trainable parameters: 5120


In [None]:
model.train()
for epoch in range(epochs):
    total_loss = 0
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{epochs}", position=0, leave=True, ncols=80) as progress_bar:
        for batch in train_loader:
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            loss = loss_fn(logits, labels)
            loss.backward()

            #torch.nn.utils.clip_grad_norm_(model.prompt_embeddings.parameters(), 1.0)

            optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix_str(f"Loss: {loss.item()}")
            progress_bar.update()

    avg_loss = total_loss / len(train_loader)
    print(f"\nEpoch {epoch + 1}/{epochs} - Average Loss: {avg_loss}")

Epoch 1/5:  10%|▎  | 207/2105 [00:37<05:29,  5.76it/s, Loss: 0.6287640929222107]

In [None]:
# Evaluate the model on the test set

model.eval()
total_correct = 0
total_samples = 0

with torch.no_grad():
    for batch in val_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        _, predicted = torch.max(logits, 1)
        #print(predicted, labels.squeeze())
        total_correct += (predicted == labels.squeeze()).sum().item()
        total_samples += labels.size(0)

accuracy = total_correct / total_samples
print(f"Test Accuracy: {accuracy * 100:.2f}%")

Test Accuracy: 77.06%
