In [None]:
# !pip install wandb evaluate seqeval 

In [None]:
from typing import Optional, Dict, List, Any, Union

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import evaluate
from tqdm import tqdm

import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader, Subset
from torchmetrics.classification import F1Score

from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    AutoConfig,
    get_cosine_schedule_with_warmup,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from spacy import displacy

import wandb
from kaggle_secrets import UserSecretsClient

from sklearn.model_selection import KFold


##### add secret for wandb

In [None]:
user_secrets = UserSecretsClient()
my_secret = user_secrets.get_secret("wandb_secret") 
wandb.login(key=my_secret)

In [28]:
df = pd.read_csv("/kaggle/input/feedback-prize-2021/train.csv")
df.describe()

Unnamed: 0,discourse_id,discourse_start,discourse_end
count,144293.0,144293.0,144293.0
mean,1618936000000.0,959.818855,1200.791203
std,2491895000.0,921.054471,1010.457306
min,1614351000000.0,0.0,3.0
25%,1616884000000.0,277.0,422.0
50%,1618862000000.0,685.0,927.0
75%,1621222000000.0,1404.0,1696.0
max,1623614000000.0,7510.0,7947.0


### Check bad characters

In [None]:
def fix_mixed_encoding(input_path, output_path):
    try:
        try:
            with open(input_path, 'r', encoding='utf-8') as f:
                content = f.read()
                
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(content)
            return True
            
        except UnicodeDecodeError:
            with open(input_path, 'rb') as f:
                byte_content = f.read()
            
            try:
                decoded_content = byte_content.decode('cp1252')
                with open(output_path, 'w', encoding='utf-8') as f:
                    f.write(decoded_content)
                return True
                
            except UnicodeDecodeError:
                fixed_content = bytearray()
                i = 0
                while i < len(byte_content):
                    for j in range(4, 0, -1):  # Try 4, 3, 2, 1 bytes
                        if i + j <= len(byte_content):
                            try:
                                chunk = byte_content[i:i+j].decode('utf-8')
                                fixed_content.extend(byte_content[i:i+j])
                                i += j
                                break
                            except UnicodeDecodeError:
                                continue
                    else:
                        try:
                            chunk = byte_content[i:i+1].decode('cp1252')
                            fixed_content.extend(chunk.encode('utf-8'))
                        except UnicodeDecodeError:
                            fixed_content.append(byte_content[i])
                        i += 1
                
                with open(output_path, 'wb') as f:
                    f.write(fixed_content)
                return True
                
    except Exception as e:
        print(f"Error processing {input_path}: {e}")
        return False

def process_directory(source_dir, destination_dir):
    success_count = 0
    failure_count = 0
    os.makedirs(destination_dir, exist_ok=True)
    
    for root, dirs, files in os.walk(source_dir):
        for file in files:
            source_file_path = os.path.join(root, file)
            
            relative_path = os.path.relpath(source_file_path, source_dir)
            destination_file_path = os.path.join(destination_dir, relative_path)
            os.makedirs(os.path.dirname(destination_file_path), exist_ok=True)
            
            if fix_mixed_encoding(source_file_path, destination_file_path):
                success_count += 1
            else:
                failure_count += 1
    
    print(f"\nProcessing complete. Successfully fixed: {success_count}, Failed: {failure_count}")

# Example usage (commented out)
# source_directory = "/kaggle/input/feedback-prize-2021/train"
# destination_directory = "/kaggle/working/feedback-prize-2021/train"
# process_directory(source_directory, destination_directory)


Processing complete. Successfully fixed: 15594, Failed: 0


#### **Cluster into grades

In [None]:
df_cluster = pd.read_csv("/kaggle/input/feedback-clustering/essays_with_predictions.csv")

In [30]:
df_cluster['id'] = df_cluster['file_name'].str.replace('.txt', '', regex=False)
df = df.merge(df_cluster[['id', 'predicted_grade_cluster']], on='id', how='left')

#### Dataset

In [None]:
def get_tag(anns):
    discourse_types = anns['discourse_type'].unique()
    tag_map = {'O': 0}
    tag_idx = 1
    for discourse_type in discourse_types:
        tag_map[f'B-{discourse_type}'] = tag_idx
        tag_idx += 1
        tag_map[f'I-{discourse_type}'] = tag_idx
        tag_idx += 1
    return tag_map


class DataProcessor():
    def __init__(self, text_dir, anns, tokenizer, tag_map, max_length=2048, include_cluster=True):
        self.text_dir = text_dir
        self.anns = anns
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.include_cluster = include_cluster
        self.file_ids = self.anns['id'].unique()
        self.discourse_types = self.anns['discourse_type'].unique()
        self.tag_map = tag_map
            
        if self.include_cluster:
            self.cluster_tokens = [f"[CLUSTER_{i}]" for i in range(6)]
            special_tokens = {"additional_special_tokens": self.cluster_tokens}
            self.tokenizer.add_special_tokens(special_tokens)
            
    def __len__(self):
        return len(self.file_ids)
        
    def __getitem__(self, idx):
        file_id = self.file_ids[idx]
        
        with open(os.path.join(self.text_dir, f"{file_id}.txt"), 'r', encoding='utf-8') as f:
            text = f.read().strip()
        text = text.replace("\n\n", " \n\n")
        
        file_annotations = self.anns[self.anns['id'] == file_id]
        words = text.split(" ")
        
        word_labels = ['O'] * len(words)
        
        for _, row in file_annotations.iterrows():
            discourse_type = row['discourse_type']
            word_indices = [int(idx) for idx in str(row['predictionstring']).split()]
            
            for i, word_idx in enumerate(word_indices):
                if word_idx < len(word_labels):
                    word_labels[word_idx] = f'B-{discourse_type}' if i == 0 else f'I-{discourse_type}'
        
        cluster_token_word = None
        if self.include_cluster and not file_annotations.empty:
            cluster = file_annotations['predicted_grade_cluster'].iloc[0]
            cluster_token_word = self.cluster_tokens[int(cluster)]
        
        if cluster_token_word:
            tokenizer_input = [cluster_token_word] + words
        else:
            tokenizer_input = words
        
        encodings = self.tokenizer(
            tokenizer_input,
            is_split_into_words=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt"
        )
        
        word_ids = encodings.word_ids(batch_index=0)
        
        token_labels = []
        for word_id in word_ids:
            if word_id is None:
                # Special tokens like [CLS], [SEP], [PAD]
                token_labels.append(-100)
            elif self.include_cluster and cluster_token_word and word_id == 0:
                # cluster token
                token_labels.append(-100)
            else:
                adjusted_word_id = word_id - 1 if (self.include_cluster and cluster_token_word) else word_id
                if adjusted_word_id < len(word_labels):
                    token_labels.append(self.tag_map[word_labels[adjusted_word_id]])
                else:
                    token_labels.append(self.tag_map['O'])
        
        return {
            'input_ids': encodings.input_ids[0],
            'attention_mask': encodings.attention_mask[0],
            'labels': torch.tensor(token_labels),
            'word_ids' : word_ids,
            'file_id': file_id,
            "words" : words
        }

           

##### Collator

In [None]:

class NERDataCollator:
    def __init__(self, tokenizer: PreTrainedTokenizerBase, max_length: int = 2048):
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_ids = [feature['input_ids'] for feature in features]
        attention_mask = [feature['attention_mask'] for feature in features]
        labels = [feature['labels'] for feature in features]
        file_ids = [feature['file_id'] for feature in features]
        words = [feature['words'] for feature in features]
        word_ids = [feature['word_ids'] for feature in features]
        
        input_ids = [ids.tolist() if isinstance(ids, torch.Tensor) else ids for ids in input_ids]
        attention_mask = [mask.tolist() if isinstance(mask, torch.Tensor) else mask for mask in attention_mask]
        labels = [lbl.tolist() if isinstance(lbl, torch.Tensor) else lbl for lbl in labels]
        
        batch_max_length = min(max(len(ids) for ids in input_ids), self.max_length)
        padded_input_ids = []
        padded_attention_mask = []
        padded_labels = []
        pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
        
        for ids, mask, lbl in zip(input_ids, attention_mask, labels):
            padding_length = batch_max_length - len(ids)
            padded_input_ids.append(ids + [pad_token_id] * padding_length)
            padded_attention_mask.append(mask + [0] * padding_length)
            padded_labels.append(lbl + [-100] * padding_length)
        
        batch = {
            "input_ids": torch.tensor(padded_input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(padded_attention_mask, dtype=torch.long),
            "labels": torch.tensor(padded_labels, dtype=torch.long),
            "file_ids": file_ids,
            "words": words,
            "word_ids": word_ids
        }
        
        return batch

##### Validation

In [14]:
def convert_predictions_to_predictionstring(file_id, word_ids, words, predicted_labels, tag_map):
    idx_to_tag = {v: k for k, v in tag_map.items()}
    
    # Check for cluster token
    has_cluster_token = False
    if len(words) > 0 and isinstance(words[0], str) and words[0].startswith('[CLUSTER_'):
        has_cluster_token = True
    
    # Track word-level predictions
    final_word_predictions = {}
    for token_idx, (word_idx, pred_label) in enumerate(zip(word_ids, predicted_labels)):
        pred_label = int(pred_label)
        if word_idx is None or pred_label == -100:
            continue
        
        if has_cluster_token and word_idx == 0:
            continue
            
        adjusted_word_idx = word_idx - 1 if has_cluster_token else word_idx
            
        if adjusted_word_idx < len(words):
            if pred_label in idx_to_tag:
                tag = idx_to_tag[pred_label]
                # Only include non-O tags in final predictions
                if tag != 'O':
                    final_word_predictions[adjusted_word_idx] = tag
    
    discourse_spans = []
    current_type = None
    current_indices = []
    
    all_word_indices = set(range(len(words)))
    if has_cluster_token:
        all_word_indices = set(range(len(words) - 1))  # cluster token
    
    for i in sorted(all_word_indices):
        if i in final_word_predictions:
            tag = final_word_predictions[i]
            
            if tag.startswith('B-') or (current_type is None and tag.startswith('I-')):
                if current_type and current_indices:
                    discourse_spans.append({
                        'discourse_type': current_type,
                        'predictionstring': ' '.join(map(str, current_indices))
                    })
                
                
                current_type = tag[2:]  # Remove B- and I-
                current_indices = [i]
                
            elif tag.startswith('I-') and current_type == tag[2:]:
                if current_indices and i == current_indices[-1] + 1:
                    current_indices.append(i)
                else:
                    if current_type and current_indices:
                        discourse_spans.append({
                            'discourse_type': current_type,
                            'predictionstring': ' '.join(map(str, current_indices))
                        })
                    current_type = tag[2:]
                    current_indices = [i]
        else:
            if current_type and current_indices:
                discourse_spans.append({
                    'discourse_type': current_type,
                    'predictionstring': ' '.join(map(str, current_indices))
                })
                current_type = None
                current_indices = []
    
    if current_type and current_indices:
        discourse_spans.append({
            'discourse_type': current_type,
            'predictionstring': ' '.join(map(str, current_indices))
        })
    
    for span in discourse_spans:
        span['id'] = file_id
    
    return discourse_spans

def merge_consecutive_spans(spans_list):
    if not spans_list:
        return pd.DataFrame(columns=['id', 'discourse_type', 'predictionstring'])
    
    processed_spans = []
    for span in spans_list:
        new_span = span.copy()
        new_span['indices'] = set(map(int, span['predictionstring'].split()))
        processed_spans.append(new_span)
    
    processed_spans.sort(key=lambda x: min(x['indices']))
    
    merged_spans = []
    current_span = None
    
    for span in processed_spans:
        if current_span is None:
            current_span = span
        elif (span['discourse_type'] == current_span['discourse_type'] and 
              span['id'] == current_span['id'] and
              any(idx == max(current_span['indices']) + 1 for idx in span['indices'])):
            current_span['indices'].update(span['indices'])
        else:
            indices_list = sorted(current_span['indices'])
            current_span['predictionstring'] = ' '.join(map(str, indices_list))
            del current_span['indices']
            merged_spans.append(current_span)
            current_span = span
    
    if current_span:
        indices_list = sorted(current_span['indices'])
        current_span['predictionstring'] = ' '.join(map(str, indices_list))
        del current_span['indices']
        merged_spans.append(current_span)
    
    df = pd.DataFrame(merged_spans)
    
    if not df.empty:
        df = df[['id', 'discourse_type', 'predictionstring']]
    
    return df

##### Model

In [53]:
class TokenClassificationModule(pl.LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        tokenizer,
        num_labels: int = 15,
        learning_rate: float = 2e-5,
        weight_decay: float = 0.01,
        warmup_steps: int = 500,
        total_steps: Optional[int] = None,
        freeze_layers: int = 0,
        id2label=None,
        label_names=None,
    ):

        super().__init__()
        
        self.save_hyperparameters(ignore=["tokenizer"])
        
        config = AutoConfig.from_pretrained(model_name_or_path)
        config.num_labels = num_labels

        self.model = AutoModelForTokenClassification.from_pretrained(
            model_name_or_path,
            config=config,
        )
        if tokenizer is not None:
            self.model.resize_token_embeddings(len(tokenizer))
        
        self.train_f1 = F1Score(task="multiclass", num_classes=num_labels, ignore_index=-100)
        self.val_f1 = F1Score(task="multiclass", num_classes=num_labels, ignore_index=-100)
        
        self.seqeval_metric = evaluate.load("seqeval")
        self.id2label = id2label
        self.label_names = label_names
        
        self.val_predictions = []
        self.val_labels = []
        
        self.validation_step_outputs = []
        
    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"]
        )
        
        loss = outputs.loss
        logits = outputs.logits

        preds = torch.argmax(logits, dim=-1)
        self.train_f1(preds, batch["labels"])
        
        self.log("train_loss", loss, on_step=True, prog_bar=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"]
        )
        
        loss = outputs.loss
        logits = outputs.logits
        
        preds = torch.argmax(logits, dim=-1)
        self.val_f1(preds, batch["labels"])
        
        self.val_predictions.extend(preds.detach().cpu().tolist())
        self.val_labels.extend(batch["labels"].detach().cpu().tolist())
        
        self.log("val_loss_step", loss, on_step=True, prog_bar=True)
        
        self.validation_step_outputs.append(loss.detach())
        
        return loss
    
    def on_train_epoch_end(self):
        train_f1 = self.train_f1.compute()
        self.log("train_f1", train_f1, on_epoch=True, prog_bar=True)
        self.train_f1.reset()
    
    def on_validation_epoch_end(self):
        avg_val_loss = torch.stack(self.validation_step_outputs).mean()
        self.log("val_loss", avg_val_loss, on_epoch=True, prog_bar=True, sync_dist=True)
        self.validation_step_outputs.clear()
        
        # Standard F1 metric
        val_f1 = self.val_f1.compute()
        self.log("val_f1", val_f1, on_epoch=True, prog_bar=True, sync_dist=True)
        self.val_f1.reset()
        
        if self.id2label is not None:
            metrics = self.compute_seqeval_metrics(self.val_predictions, self.val_labels)
            self.log("val_seqeval_precision", metrics["precision"], on_epoch=True, prog_bar=True, sync_dist=True)
            self.log("val_seqeval_recall", metrics["recall"], on_epoch=True, prog_bar=True, sync_dist=True)
            self.log("val_seqeval_f1", metrics["f1"], on_epoch=True, prog_bar=True, sync_dist=True)
            self.log("val_seqeval_accuracy", metrics["accuracy"], on_epoch=True, prog_bar=True, sync_dist=True)
        
        self.val_predictions = []
        self.val_labels = []
    
    def compute_seqeval_metrics(self, predictions, labels):
        true_labels = [[self.id2label[l] for l in label if l != -100] for label in labels]
        true_predictions = [
            [self.id2label[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        all_metrics = self.seqeval_metric.compute(predictions=true_predictions, references=true_labels)
        return {
            "precision": all_metrics["overall_precision"],
            "recall": all_metrics["overall_recall"],
            "f1": all_metrics["overall_f1"],
            "accuracy": all_metrics["overall_accuracy"],
        }
    
    def predict_step(self, batch, batch_idx):
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"]
        )
        
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        
        return {
            "logits": logits,
            "predictions": predictions,
            "essay_id": batch.get("essay_id", None),
            "words": batch.get("words", None),
            "word_labels": batch.get("word_labels", None)
        }
    
    def configure_optimizers(self):
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        
        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters, 
            lr=self.hparams.learning_rate
        )
        
        if self.hparams.total_steps is None:
            return optimizer
        
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.hparams.total_steps,
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }

##### refine output with beam search

In [22]:
def beam_search(logits, beam_size = 5):
    preds = []
    probs = torch.nn.functional.softmax(logits, dim=-1)

    # take top-k labels for each token
    topk_probs, topk_indices = torch.topk(
        probs, beam_size, dim=-1
    )  # shape: (batch_size, seq_len, beam_width)

    # choose the most probable sequence from beam search
    best_seq = topk_indices[:, :, 0].cpu().numpy()

    preds.extend(best_seq)

    return preds

##### Train/Val in KFolds

In [57]:

def train_with_folds_lightning(
    data_processor,
    model_class,
    model_name_or_path,
    anns,
    data_collator,
    id2label,
    n_splits=5, 
    batch_size=8,
    num_workers=4,
    max_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    project_name="token-classification",
    entity_name="vvilgurin-igor-sikorsky-kyiv-polytechnic-institute",
    experiment_name="token-classification-cv-1",
    patience=3,
    using_clusters=True
):

    print("init wandb")
    wandb.init(
        project=project_name,
        entity=entity_name,
        name=experiment_name,
        config={
            "n_splits": n_splits,
            "batch_size": batch_size,
            "max_epochs": max_epochs,
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "model_name": model_name_or_path,
            "using_clusters": using_clusters
        }
    )

    using_clusters = hasattr(data_processor, 'include_cluster') and data_processor.include_cluster
    file_id_to_idx = {file_id: idx for idx, file_id in enumerate(data_processor.file_ids)}
    unique_ids = data_processor.file_ids
    
    print("splitting folds")
    if using_clusters:
        file_clusters = {}
        for file_id in unique_ids:
            file_anns = anns[anns['id'] == file_id]
            if not file_anns.empty and 'predicted_grade_cluster' in file_anns.columns:
                # Handle potential NaN values
                cluster_value = file_anns['predicted_grade_cluster'].iloc[0]
                if pd.isna(cluster_value):
                    file_clusters[file_id] = -1
                else:
                    file_clusters[file_id] = int(cluster_value)
            else:
                file_clusters[file_id] = -1
                
        print("grouping clusters")
        cluster_to_files = {}
        for file_id, cluster in file_clusters.items():
            if cluster not in cluster_to_files:
                cluster_to_files[cluster] = []
            cluster_to_files[cluster].append(file_id)
        
        folds = [[] for _ in range(n_splits)]
        print("filling the clusters")
        for cluster, files in cluster_to_files.items():
            import random
            random.seed(42) 
            random.shuffle(files)
            
            for i, file_id in enumerate(files):
                fold_idx = i % n_splits
                folds[fold_idx].append(file_id)
        
        print("convert to train/test")
        fold_indices = []
        for i in range(n_splits):
            val_file_ids = folds[i]
            train_file_ids = [fid for j in range(n_splits) if j != i for fid in folds[j]]
            
            train_indices = [file_id_to_idx[fid] for fid in train_file_ids]
            val_indices = [file_id_to_idx[fid] for fid in val_file_ids]
            
            fold_indices.append((train_indices, val_indices))
    else:
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
        fold_indices = []
        
        for train_idx, val_idx in kf.split(range(len(unique_ids))):
            fold_indices.append((train_idx.tolist(), val_idx.tolist()))
    
    all_results = []
    best_models = []
    
    print("Training start")
    for fold, (train_indices, val_indices) in enumerate(tqdm(fold_indices, desc="Cross-validation folds")):
        fold_name = f"{experiment_name}_fold_{fold+1}"
        print(f"Training fold {fold+1}/{n_splits}")
        
        train_dataset = Subset(data_processor, train_indices)
        val_dataset = Subset(data_processor, val_indices)
        
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=data_collator,
            num_workers=num_workers
        )
        
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=data_collator,
            num_workers=num_workers
        )
        
        wandb_logger = WandbLogger(
            project=project_name,
            entity=entity_name,
            name=fold_name,
            group=experiment_name, 
            log_model=True
        )
        
        total_steps = len(train_dataloader) * max_epochs
        warmup_steps = int(0.1 * total_steps)  
        
        lightning_model = model_class(
            model_name_or_path=model_name_or_path,
            num_labels=data_processor.num_labels if hasattr(data_processor, 'num_labels') else 15,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            warmup_steps=warmup_steps,
            total_steps=total_steps, 
            tokenizer=data_processor.tokenizer,
            id2label=id2label,
            label_names=list(tags.keys())
        )
        
        wandb_logger.log_hyperparams({
            "fold": fold + 1,
            "train_size": len(train_dataset),
            "val_size": len(val_dataset)
        })
        
        checkpoint_callback = ModelCheckpoint(
            dirpath=f"checkpoints/{experiment_name}/fold_{fold+1}",
            filename="model-{epoch:02d}-{val_f1:.4f}",
            monitor="val_f1",
            mode="max",
            save_top_k=1,
            save_weights_only=False
        )
        
        early_stop_callback = EarlyStopping(
            monitor="val_f1",
            min_delta=0.001,
            patience=patience,
            verbose=True,
            mode="max"
        )
        
        trainer = pl.Trainer(
            max_epochs=max_epochs,
            devices=1, 
            accelerator="auto",
            logger=wandb_logger,
            callbacks=[checkpoint_callback, early_stop_callback],
            enable_checkpointing=True,
            deterministic=True,
            accumulate_grad_batches=8
        )
        
        trainer.fit(
            model=lightning_model,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader
        )
        
        fold_results = trainer.validate(lightning_model, val_dataloader)
        all_results.append(fold_results)
        
        best_models.append({
            "fold": fold + 1,
            "val_f1": fold_results[0]["val_f1"],
            "checkpoint_path": checkpoint_callback.best_model_path
        })

        print(f"Fold {fold+1} validation results: {fold_results}")
        
        wandb.finish()
    
    best_model_info = max(best_models, key=lambda x: x["val_f1"])
    
    val_losses = [result[0]["val_loss"] for result in all_results]
    val_f1s = [result[0]["val_f1"] for result in all_results]
    
    cv_results = {
        "mean_val_loss": np.mean(val_losses),
        "std_val_loss": np.std(val_losses),
        "mean_val_f1": np.mean(val_f1s),
        "std_val_f1": np.std(val_f1s),
        "best_model": best_model_info
    }
    
    wandb.init(
        project=project_name,
        entity=entity_name,
        name=f"{experiment_name}_cv_summary",
        group=experiment_name
    )
    
    wandb.log({
        "cv_mean_val_loss": cv_results["mean_val_loss"],
        "cv_std_val_loss": cv_results["std_val_loss"],
        "cv_mean_val_f1": cv_results["mean_val_f1"],
        "cv_std_val_f1": cv_results["std_val_f1"],
        "best_fold": best_model_info["fold"],
        "best_fold_f1": best_model_info["val_f1"]
    })
    
    fold_table = wandb.Table(columns=["Fold", "Validation Loss", "Validation F1"])
    for i, result in enumerate(all_results):
        fold_table.add_data(i+1, result[0]["val_loss"], result[0]["val_f1"])
    
    wandb.log({"fold_results_table": fold_table})
    
    wandb.finish()
    
    # Print summary
    print("\n" + "="*50)
    print(f"Cross-validation completed with {n_splits} folds")
    print(f"Mean Validation F1: {cv_results['mean_val_f1']:.4f} ± {cv_results['std_val_f1']:.4f}")
    print(f"Mean Validation Loss: {cv_results['mean_val_loss']:.4f} ± {cv_results['std_val_loss']:.4f}")
    print(f"Best model from fold {best_model_info['fold']} with F1: {best_model_info['val_f1']:.4f}")
    print(f"Best model path: {best_model_info['checkpoint_path']}")
    print("="*50)
    
    return cv_results


In [None]:
tags = get_tag(df)
id2label = {id: label for label, id in tags.items()}
anns = df
data_processor = DataProcessor(text_dir, anns, tokenizer, tags)
data_collator = NERDataCollator(tokenizer)
# anns = pd.read_csv('/kaggle/input/feedback-prize-2021/train.csv')

cv_results = train_with_folds_lightning(
    data_processor=data_processor,
    model_class=TokenClassificationModule,
    model_name_or_path="google/bigbird-roberta-base", 
    anns=anns,
    id2label = id2label,
    data_collator=data_collator,
    n_splits=3,
    batch_size=2,
    max_epochs=2,
    learning_rate=2e-5,
    experiment_name="big-bird-folds-cluster"
)


init wandb
splitting folds
grouping clusters
filling the clusters
convert to train/test
Training start



Cross-validation folds:   0%|          | 0/3 [00:00<?, ?it/s][A

Training fold 1/3


Some weights of BigBirdForTokenClassification were not initialized from the model checkpoint at google/bigbird-roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 2. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
Attention type 'block_sparse' is not possible if sequence_length: 699 <= num global tokens: 2 * config.block_size + min. num sliding tokens: 3 * config.block_size + config.num_random_blocks * config.block_size + additional buffer: config.num_random_blocks * config.block_size = 704 with config.block_size = 64, config.num_random_blocks = 3. Changing attention type to 'original_full'...


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 1 validation results: [{'val_loss_step_epoch': 0.6806517839431763, 'val_loss': 0.6806506514549255, 'val_f1': 0.7864494919776917, 'val_seqeval_precision': 0.28021547198295593, 'val_seqeval_recall': 0.38007569313049316, 'val_seqeval_f1': 0.3225943148136139, 'val_seqeval_accuracy': 0.7864494919776917}]


0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█████
train_f1,▁█
train_loss,█▄▆▄▅▆▄▅▄▄▆▂▅▂▂▂▃▂▅▂▂▂▂▃▂▃▄▅▄▂▂▁▃▃▃▃▃▃▄▃
trainer/global_step,█▁▁▁▂▂▂▃▃▃▃▃▃▃▄▄▄▂▅▅▅▆▆▇▇▇▇██▂▂▂▂▂▂▃▃▃▄▄
val_f1,▁██
val_loss,█▁▁
val_loss_step_epoch,█▁▁
val_loss_step_step,▃▆▇▆▇▇▅█▆▂▆▄▁▅▄▄▅▅▃▅▄▃▄▄█▆▆▆▅▅▅▅▃▃▂▅▅▅▄▅
val_seqeval_accuracy,▁██
val_seqeval_f1,▁██

0,1
epoch,2.0
train_f1,0.77508
train_loss,0.44424
trainer/global_step,1300.0
val_f1,0.78645
val_loss,0.68065
val_loss_step_epoch,0.68065
val_loss_step_step,0.74082
val_seqeval_accuracy,0.78645
val_seqeval_f1,0.32259



  lambda data: self._console_raw_callback("stderr", data),
Cross-validation folds:  33%|███▎      | 1/3 [41:57<1:23:55, 2517.84s/it][A

Training fold 2/3


Some weights of BigBirdForTokenClassification were not initialized from the model checkpoint at google/bigbird-roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Input ids are automatically padded from 717 to 768 to be a multiple of `config.block_size`: 64
Input ids are automatically padded from 779 to 832 to be a multiple of `config.block_size`: 64


Training: |          | 0/? [00:00<?, ?it/s]

Attention type 'block_sparse' is not possible if sequence_length: 537 <= num global tokens: 2 * config.block_size + min. num sliding tokens: 3 * config.block_size + config.num_random_blocks * config.block_size + additional buffer: config.num_random_blocks * config.block_size = 704 with config.block_size = 64, config.num_random_blocks = 3. Changing attention type to 'original_full'...


In [None]:
best_model = TokenClassificationModule.load_from_checkpoint(cv_results["best_model"]["checkpoint_path"])