In [8]:
import yaml
import multiprocessing

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk

In [2]:
# Worker-process budget: every core the OS reports except one, but never fewer than one.
total_cores = multiprocessing.cpu_count()
num_cores_avail = max(1, total_cores - 1)

In [3]:
# Parse the experiment's main YAML config into a plain Python object.
main_config_path = "../experiments/configs/pitchfork_cls/main.yaml"
with open(main_config_path) as config_file:
    main_config = yaml.safe_load(config_file)

In [4]:
# Dataset identifier and its revision, as named in the config.
dataset_checkpoint, dataset_checkpoint_revision = (
    main_config["dataset_checkpoint"],
    main_config["dataset_checkpoint_revision"],
)

# Model identifier and its revision, as named in the config.
model_checkpoint, model_checkpoint_revision = (
    main_config["model_checkpoint"],
    main_config["model_checkpoint_revision"],
)

In [5]:
# The encoder and its tokenizer are loaded from the same checkpoint and revision
# so that vocabulary and weights stay in sync.
embedding_model = AutoModel.from_pretrained(model_checkpoint, revision=model_checkpoint_revision)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, revision=model_checkpoint_revision)

# Previously saved dataset splits, read from local disk.
datasets = load_from_disk("../data/pitchfork/dataset_dbr/")

In [6]:
# Columns to keep after tokenization: identifying fields, the rating, and the
# two tokenizer outputs consumed downstream.
keeper_cols = ["artist", "album", "year_released", "rating", "input_ids", "attention_mask"]

# Every other column present in the train split is marked for removal.
drop_cols = {col for col in datasets["train"].column_names if col not in keeper_cols}

In [7]:
def _tokenize_reviews(examples):
    """Tokenize the "review" field of a batch of examples, with padding and truncation on."""
    return tokenizer(examples["review"], padding=True, truncation=True)


# Tokenize every split in parallel batches, then discard the columns that are no longer needed.
tokenized_datasets = datasets.map(
    _tokenize_reviews,
    batched=True,
    num_proc=num_cores_avail,
).remove_columns(drop_cols)

In [11]:
def collate_reviews(batch, pad_token_id=0):
    """Collate tokenized review examples into batched tensors.

    Rows whose token sequences differ in length are right-padded to the
    longest sequence in the batch, so the function no longer relies on every
    example having been padded to one common length beforehand.

    Args:
        batch: Sequence of examples, each a mapping with "input_ids" and
            "attention_mask" (sequences of ints) and a numeric "rating".
        pad_token_id: Token id appended to short "input_ids" rows. Defaults
            to 0; pass ``tokenizer.pad_token_id`` if the tokenizer in use
            pads with a different id.

    Returns:
        A tuple ``(input_ids, attention_masks, ratings)``:
        ``input_ids`` and ``attention_masks`` are int64 tensors of shape
        (batch, longest_sequence); ``ratings`` is a float32 tensor of shape
        (batch,).
    """
    # Pull the per-example fields out of the batch (copied so padding below
    # never mutates the caller's data).
    input_ids = [list(item['input_ids']) for item in batch]
    attention_masks = [list(item['attention_mask']) for item in batch]
    ratings = [item['rating'] for item in batch]

    # Right-pad to a rectangle; padded positions get attention 0 so the
    # encoder ignores them. Already-rectangular batches pass through unchanged.
    max_len = max((len(ids) for ids in input_ids), default=0)
    input_ids = [ids + [pad_token_id] * (max_len - len(ids)) for ids in input_ids]
    attention_masks = [mask + [0] * (max_len - len(mask)) for mask in attention_masks]

    input_ids = torch.tensor(input_ids, dtype=torch.long)
    attention_masks = torch.tensor(attention_masks, dtype=torch.long)
    # Explicit float32 keeps the regression target floating-point even when
    # every rating in the batch happens to be an integer.
    ratings = torch.tensor(ratings, dtype=torch.float32)

    return input_ids, attention_masks, ratings

In [44]:
class TextRegressor(nn.Module):
    """Encoder plus a single linear layer that maps the first-token embedding to a prediction.

    Args:
        embedder: Callable encoder accepting ``input_ids`` and ``attention_mask``
            keyword arguments and returning an object with a
            ``last_hidden_state`` attribute indexed as (batch, position, feature).
        embed_dim: Size of the feature axis of ``last_hidden_state``.
        output_dim: Number of values predicted per example (default 1).
    """

    def __init__(self, embedder, embed_dim, output_dim=1):
        super().__init__()
        self.embedder = embedder
        self.regression_head = nn.Linear(in_features=embed_dim, out_features=output_dim)

    def forward(self, input_ids, attention_mask):
        """Return predictions of shape (batch, output_dim) for a batch of token ids."""
        encoder_output = self.embedder(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of every sequence is used as the summary vector (the [CLS] slot).
        first_token_vector = encoder_output.last_hidden_state[:, 0, :]
        return self.regression_head(first_token_vector)

In [45]:
# Batches of 8 validation examples, converted to tensors by collate_reviews.
valid_dataloader = DataLoader(
    dataset=tokenized_datasets["validation"],
    batch_size=8,
    collate_fn=collate_reviews,
)

In [46]:
# Take only the first (index, batch) pair for inspection. Unlike a for/break
# loop, next() raises StopIteration immediately if the loader yields nothing,
# instead of silently leaving `batch` unset (or stale from an earlier run).
batch_idx, batch = next(enumerate(valid_dataloader))

In [47]:
# Unpack the 3-tuple that collate_reviews returns for each batch.
input_ids, attention_masks, ratings = batch

In [48]:
# Display the batch's token-id tensor.
input_ids

tensor([[  101, 27166, 13146,  ...,   112,   187,   102],
        [  101, 10657,   117,  ..., 13028,   112,   102],
        [  101,   107, 11065,  ..., 10111, 18850,   102],
        ...,
        [  101, 14178,   152,  ...,   169, 31293,   102],
        [  101, 10117, 10992,  ...,   187, 17038,   102],
        [  101, 11984, 72920,  ..., 51555, 44667,   102]])

In [49]:
# Decode id 101 (the first id of every row shown above) to see which token it is.
tokenizer.decode([101])

'[CLS]'

In [50]:
# Wrap the pretrained encoder with the linear regression head.
# `hidden_size` is the generic transformers config name for the encoder width;
# DistilBERT exposes it as an alias of `dim`, so this also works if the
# configured checkpoint is switched to a non-DistilBERT encoder.
model = TextRegressor(
    embedding_model,
    embed_dim=embedding_model.config.hidden_size
)

In [51]:
# Display the module tree of the assembled regressor.
model

TextRegressor(
  (embedder): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Line

In [52]:
# Smoke-test a forward pass on one batch without tracking gradients.
# Call the module itself rather than .forward() so that nn.Module.__call__
# runs any registered forward hooks.
with torch.no_grad():
    yhat = model(input_ids=input_ids, attention_mask=attention_masks)

In [53]:
# Display the predictions for the batch (one row per example).
yhat

tensor([[0.0733],
        [0.0557],
        [0.0485],
        [0.0741],
        [0.0602],
        [0.0733],
        [0.0727],
        [0.0574]])