In [None]:
import json
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Union, Any

import torch
from dotenv import load_dotenv
from numpy.typing import NDArray
from rich import print
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm
from transformers import RobertaModel, RobertaTokenizer
import sys
import numpy as np

In [None]:
# Load environment variables from the project's env file so later cells can
# read them through os.environ.
# NOTE: the path is relative, so it is resolved against the kernel's current
# working directory — start the notebook from its own folder.
ENV_Path = Path("../envs/n24.env")
load_dotenv(str(ENV_Path))

In [None]:
# Resolve the dataset and source roots from the environment.
# Index os.environ directly (rather than .get) so a missing variable fails
# immediately with a KeyError that names it, instead of Path(None) raising an
# unrelated-looking TypeError.
data_dir = Path(os.environ["DATASET_ROOT"])
src_dir = Path(os.environ["SOURCE_PATH"])

In [None]:
# sys.path entries must be plain strings: the import system ignores other
# types, so a pathlib.Path appended as-is would have no effect.
# The membership check keeps the cell idempotent when it is re-run.
if str(src_dir) not in sys.path:
    sys.path.append(str(src_dir))

In [None]:
# NOTE(review): hard-coded, machine-specific absolute path — it will not exist
# on another machine. Presumably this is the project root that makes the
# `src.` package importable; confirm whether SOURCE_PATH (src_dir) already
# covers it so this line can be dropped.
sys.path.append("/Users/vigneshkannan/Documents/Projects/MultiLabel_N24/")
from src.preprocess.roberta_preprocessor import RoBERTaPreprocessor
# Shared preprocessor instance used by the dataset below.
# max_length=512 is presumably the maximum token length per sample — TODO confirm
# against RoBERTaPreprocessor.
preprocessor = RoBERTaPreprocessor(max_length=512)

### Data-Module

In [None]:
from src.utils.data import load_datajson

In [None]:
# Load the training split. load_datajson is unpacked into
# (texts, labels, num_classes); note that its `data_dir` argument receives a
# path to a JSON file here, not a directory.
train_texts, train_labels, num_classes = load_datajson(
    data_dir=data_dir / "news" / "nytimes_train.json"
)

# NOTE(review): the fine-tune split reads the SAME file as the training split
# ("nytimes_train.json"), so ftune_* duplicates train_* — this looks like a
# copy-paste slip; confirm the intended file name.
# `num_classes` is also re-bound by this call, overwriting the value above.
ftune_texts, ftune_labels, num_classes = load_datajson(
    data_dir=data_dir / "news" / "nytimes_train.json"
)

In [None]:
class News24Dataset(Dataset):
    """Map-style dataset that pairs each text with its label.

    Every sample is cleaned and encoded lazily, at access time, through the
    supplied preprocessor.

    Args:
        texts: Indexable collection of texts; each item is coerced with ``str``.
        labels: Indexable collection of labels aligned with ``texts``; each
            item is converted to a ``torch.long`` tensor.
        preprocessor: Object providing ``process_text`` and
            ``encode_for_model``.
    """

    def __init__(
        self,
        texts: Union[NDArray, List],
        labels: List,
        preprocessor: RoBERTaPreprocessor,
    ):
        self.texts, self.labels, self.preprocessor = texts, labels, preprocessor

    def __len__(self):
        """Number of samples, taken from the length of ``texts``."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the encoded sample at ``idx``.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` (both flattened)
            and ``label`` (a ``torch.long`` tensor).

        Raises:
            Exception: If the preprocessor returns ``None`` as the encoding.
        """
        raw_text = str(self.texts[idx])
        target = self.labels[idx]

        # Clean the raw text first, then turn it into model inputs.
        cleaned = self.preprocessor.process_text(raw_text)
        encoded = self.preprocessor.encode_for_model(cleaned)
        target_tensor = torch.tensor(target, dtype=torch.long)
        if encoded is None:
            raise Exception(f"Failed to produce encoding for index: {idx}")

        input_ids = encoded["input_ids"].flatten()
        # Per the original note: 1 marks real tokens, 0 marks padding.
        attention_mask = encoded["attention_mask"].flatten()
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "label": target_tensor,
        }

In [None]:
# Wrap the training texts/labels in the dataset; samples are encoded lazily
# by `preprocessor` each time an index is accessed.
train_ds = News24Dataset(
    texts=train_texts, labels=train_labels, preprocessor=preprocessor,
)

In [None]:
# ## Optional smoke test (disabled): iterate over every sample once to confirm each one encodes without error.
# for _ in tqdm(train_ds):
#     pass

In [None]:
# Sanity check: encode the first sample and show which fields it carries
# (News24Dataset.__getitem__ returns "input_ids", "attention_mask", "label").
train_ds[0].keys()

### Trainer module to predict News-Category

In [None]:
from torch import nn
class SectionClassifier(nn.Module):
    """RoBERTa encoder followed by dropout and a single linear classification head.

    Args:
        n_classes: Number of output classes (width of the logits).
        roberta_type: Name or path passed to ``RobertaModel.from_pretrained``.
        input_type: Input modality; only ``"text"`` (case-insensitive) is supported.

    Raises:
        ValueError: If ``input_type`` is anything other than ``"text"``.
    """

    def __init__(self, n_classes: int, roberta_type: str = "roberta-base", input_type: str = "text") -> None:
        super().__init__()

        if input_type.lower() != "text":
            raise ValueError(f"Unable to support modality: {input_type}. The current setup only supports `text`")
        self.roberta_type = roberta_type
        self.n_classes = n_classes
        self.model = RobertaModel.from_pretrained(self.roberta_type)
        self.dropout = nn.Dropout(p=0.3)
        # The head must be sized by the encoder's hidden width. `config.hidden_size`
        # is that width; `config.output_hidden_states` is a boolean flag and
        # would give the layer the wrong input size.
        self.fc = nn.Linear(self.model.config.hidden_size, self.n_classes)  ## Need to see if we need deeper MLP!

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class logits for a batch.

        Args:
            input_ids: Token ids, assumed to be ``(batch, seq_len)`` — TODO confirm.
            attention_mask: Mask aligned with ``input_ids``.

        Returns:
            Logits with ``n_classes`` entries along the last dimension.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Use the representation at the first sequence position (the `<s>`
        # token slot for RoBERTa) as the pooled sentence embedding.
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        return self.fc(pooled_output)

### Training:

In [None]:
def train_epoch(
    model: nn.Module,
    dl: DataLoader,
    loss_fn: Any,
    optimizer: torch.optim.Optimizer,
    device: Union[str, torch.device],
    n_examples: int,
) -> Dict:
    """Run one optimisation pass over ``dl`` and report accuracy and mean loss.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``.
        dl: Iterable of batches; each batch is a dict of tensors containing
            ``"input_ids"``, ``"attention_mask"`` and ``"label"``.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        optimizer: Optimizer whose gradients are zeroed and stepped per batch.
        device: Device every tensor in the batch is moved to.
        n_examples: Denominator for the accuracy (total number of samples).

    Returns:
        Dict with ``"acc"`` (0-dim float64 CPU tensor, correct / n_examples)
        and ``"avg_loss"`` (float; NaN when ``dl`` yields no batches).
    """
    model.train()
    losses = []
    correct_preds = 0

    for batch in tqdm(dl):
        optimizer.zero_grad()
        batch = {key: value.to(device) for key, value in batch.items()}
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

        # Predicted class = index of the largest logit along dim 1.
        preds = outputs.argmax(dim=1)
        loss = loss_fn(outputs, batch["label"])

        # Accumulate as a Python int so the counter is well-defined even when
        # the loader is empty and never has to live on the accelerator.
        correct_preds += int((preds == batch["label"]).sum().item())
        losses.append(loss.item())

        loss.backward()
        optimizer.step()

    return {
        # Built on CPU: some backends (e.g. Apple MPS) do not support float64.
        "acc": torch.tensor(correct_preds, dtype=torch.float64) / n_examples,
        "avg_loss": float(np.mean(losses)) if losses else float("nan"),
    }

def eval_epoch(
    model: nn.Module,
    dl: DataLoader,
    loss_fn: Any,
    device: Union[str, torch.device],
    n_examples: int,
) -> Dict:
    """Evaluate ``model`` over ``dl`` without updating it.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``.
        dl: Iterable of batches; each batch is a dict of tensors containing
            ``"input_ids"``, ``"attention_mask"`` and ``"label"``.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        device: Device every tensor in the batch is moved to.
        n_examples: Denominator for the accuracy (total number of samples).

    Returns:
        Dict with ``"acc"`` (0-dim float64 CPU tensor, correct / n_examples)
        and ``"avg_loss"`` (float; NaN when ``dl`` yields no batches).
    """
    model.eval()
    losses = []
    correct_preds = 0

    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch.
    with torch.no_grad():
        for batch in tqdm(dl):
            batch = {key: value.to(device) for key, value in batch.items()}
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
            )

            # Predicted class = index of the largest logit along dim 1.
            preds = outputs.argmax(dim=1)
            loss = loss_fn(outputs, batch["label"])

            # Python-int accumulator: well-defined for an empty loader too.
            correct_preds += int((preds == batch["label"]).sum().item())
            losses.append(loss.item())

    return {
        # Built on CPU: some backends (e.g. Apple MPS) do not support float64.
        "acc": torch.tensor(correct_preds, dtype=torch.float64) / n_examples,
        "avg_loss": float(np.mean(losses)) if losses else float("nan"),
    }