## Fine-Tuning DistilBERT for Extractive Question Answering

In [None]:
! pip install transformers datasets seaborn bertviz umap-learn wandb

### Load Dataset
Dataset was downloaded as a CSV from https://www.kaggle.com/datasets/ashpalsingh1525/imdb-movies-dataset

In [None]:
import pandas as pd

df = pd.read_csv('imdb_movies.csv')

In [None]:
print(f"Shape: {df.shape}")
print(f"Columns: {df.columns}\n")
df.head()

### Build Context and QAs

In [None]:
def build_context_string(row):
    """Build the natural-language context paragraph for one movie.

    Args:
        row: Mapping-like record (for example a DataFrame row) providing the
            keys 'names', 'genre', 'date_x', 'budget_x', 'revenue' and
            'overview'.

    Returns:
        A sentence stating the movie facts, followed by a second sentence
        with the overview when a non-blank overview string is present.
    """
    context = f"{row['names']} is a {row['genre']} movie released in {row['date_x']} with budget of {row['budget_x']} had revenue {row['revenue']}."

    # A missing overview may arrive as float NaN (truthy) rather than '' or
    # None, so test for real, non-blank text instead of relying on truthiness;
    # otherwise the context would end in "... all about nan".
    description = row['overview']
    if isinstance(description, str) and description.strip():
        # Leading space keeps the two sentences from running together.
        context += f" The movie is all about {description}"

    return context

In [None]:
def build_question_string(row, context):
    """Create SQuAD-style question/answer records for one movie.

    For each of genre, budget, revenue, year and description a question is
    generated and its answer is located in ``context`` by substring search.

    Args:
        row: Mapping-like record providing 'names', 'genre', 'budget_x',
            'revenue', 'date_x' and 'overview'.
        context: The context paragraph the answers must appear in.

    Returns:
        A list of dicts with 'question', 'id', 'answers' (a one-element list
        holding 'text' and 'answer_start') and 'is_impossible'. Fields whose
        value is missing, empty, or not found in ``context`` are omitted.
    """
    title = str(row['names'])
    fields = [
        ('genre', row['genre']),
        ('budget', row['budget_x']),
        ('revenue', row['revenue']),
        ('year', row['date_x']),
        ('description', row['overview']),
    ]

    qa_list = []

    for question, value in fields:
        # str(NaN) is 'nan', which can match by accident inside unrelated
        # text (e.g. a title such as "Banana"); skip missing values instead.
        if pd.isna(value):
            continue

        answer = str(value)
        # ''.find-style lookups succeed at offset 0, which would produce an
        # empty answer span; skip those as well.
        if not answer:
            continue

        answer_start = context.find(answer)
        if answer_start == -1:
            continue

        qa_list.append({
            'question': f"What is the {question} of the movie {title}?",
            'id': f"{title}_{question}",
            'answers': [
                {
                    'text': answer,
                    'answer_start': answer_start,
                },
            ],
            'is_impossible': False,
        })

    return qa_list

### Build SQuAD JSON

In [None]:
def build_squad_dataset(df):
    """Assemble a SQuAD-style dictionary from the movie DataFrame.

    Every row is turned into one entry holding a single paragraph (the
    generated context plus its QA records). Rows that produce no QA records
    are left out.

    Args:
        df: DataFrame of movies; rows are passed to ``build_context_string``
            and ``build_question_string``.

    Returns:
        A dict with a 'title' string and a 'data' list of entries.
    """
    entries = []

    for _, movie in df.iterrows():
        paragraph_text = build_context_string(movie)
        questions = build_question_string(movie, paragraph_text)

        # Keep only movies for which at least one answer was located.
        if questions:
            entries.append({
                'title': movie['names'],
                'paragraphs': [{
                    'context': paragraph_text,
                    'qas': questions,
                }],
            })

    return {'title': 'IMDB Movies QA V1', 'data': entries}

In [None]:
squad_data = build_squad_dataset(df)
squad_data['data'][:2]

### Convert to Hugging Face dataset

In [None]:
from datasets import load_dataset, Dataset, DatasetDict

def squad_dict_to_dataset(squad_dict):
    """Flatten a SQuAD-style dictionary into a Dataset, one row per question.

    Args:
        squad_dict: Dict whose 'data' list holds entries with 'title' and
            'paragraphs'; each paragraph has 'context' and 'qas'.

    Returns:
        A ``Dataset`` built from a DataFrame with the columns 'title',
        'context', 'question', 'id', 'is_impossible', 'answers',
        'answer_starts' and 'answer_texts'.
    """
    rows = []

    for article in squad_dict["data"]:
        article_title = article['title']

        for para in article['paragraphs']:
            para_context = para['context']

            for qa in para['qas']:
                answer_list = qa['answers']
                rows.append({
                    'title': article_title,
                    'context': para_context,
                    'question': qa['question'],
                    'id': qa['id'],
                    'is_impossible': qa['is_impossible'],
                    'answers': answer_list,
                    # Parallel lists pulled out of the answer dicts
                    'answer_starts': [item['answer_start'] for item in answer_list],
                    'answer_texts': [item['text'] for item in answer_list],
                })

    frame = pd.DataFrame(rows)
    return Dataset.from_pandas(frame)

In [None]:
dataset = squad_dict_to_dataset(squad_data)
dataset

In [None]:
# Fixed seed so the three splits (and any metric computed on them) are
# reproducible across notebook runs.
SPLIT_SEED = 42

# 70% train, 30% held out
train_test = dataset.train_test_split(test_size=0.3, seed=SPLIT_SEED)
train_dataset = train_test['train']

# Split the held-out 30% into 20% test and 10% validation
# (one third of the held-out rows go to validation).
held_out = train_test['test'].train_test_split(test_size=1/3, seed=SPLIT_SEED)
test_dataset = held_out['train']
validation_dataset = held_out['test']

train_dataset, test_dataset, validation_dataset

### Prepare dataset

In [None]:
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering, TrainingArguments, Trainer, default_data_collator
import torch

MODEL = 'distilbert-base-uncased'

tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL)

In [None]:
max_length = 384  # Max length of the encoding
doc_stride = 128  # Stride to handle long contexts

def prepare_train_features(examples):
    """Tokenize a batch of QA examples and attach answer token positions.

    Question/context pairs are tokenized with the module-level ``tokenizer``.
    Contexts longer than ``max_length`` are split into overlapping windows
    (overlap ``doc_stride``), so one example may yield several features.

    Args:
        examples: Batch (dict of lists) with "question", "context" and
            "answers" columns. Each "answers" item is a list of dicts with
            "text" and "answer_start"; only the first answer is used.

    Returns:
        The tokenizer output with "start_positions" and "end_positions"
        added. For a window that does not contain the whole answer, both
        positions are the index of the CLS token.
    """
    # 1) Tokenize questions + contexts
    tokenized = tokenizer(
        examples["question"],
        examples["context"],
        max_length=max_length,
        truncation="only_second",    # We only truncate the context if it's too long
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # 2) "overflow_to_sample_mapping" indicates which original sample each chunk belongs to
    sample_mapping = tokenized.pop("overflow_to_sample_mapping")
    # 3) "offset_mapping" provides (char_start, char_end) for each token in the original text
    offset_mapping = tokenized.pop("offset_mapping")

    # 4) We'll need the answer information to map them back to token indices
    answers = examples["answers"]

    start_positions = []
    end_positions = []

    # 5) Loop over each tokenized chunk
    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized["input_ids"][i]
        # Label used for windows that do not hold the answer
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # The current chunk corresponds to this original example; use its first answer
        answer = answers[sample_mapping[i]][0]
        start_char = answer["answer_start"]
        end_char = start_char + len(answer["text"])

        # sequence_ids is 1 for context tokens; question, special and padding
        # tokens carry other values, so these scans bracket the context span.
        sequence_ids = tokenized.sequence_ids(i)

        ctx_start = 0
        while sequence_ids[ctx_start] != 1:
            ctx_start += 1
        ctx_end = len(input_ids) - 1
        while sequence_ids[ctx_end] != 1:
            ctx_end -= 1

        # The answer must lie completely inside this window. Accepting a
        # partial overlap would make the scans below stop immediately and
        # yield ctx_start - 1 or ctx_end + 1, i.e. a special/padding token.
        if offsets[ctx_start][0] > start_char or offsets[ctx_end][1] < end_char:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Last context token whose start offset is <= the answer's start_char
        start_idx = ctx_start
        while start_idx <= ctx_end and offsets[start_idx][0] <= start_char:
            start_idx += 1
        start_positions.append(start_idx - 1)

        # First context token (scanning from the right) whose end offset is >= end_char
        end_idx = ctx_end
        while end_idx >= ctx_start and offsets[end_idx][1] >= end_char:
            end_idx -= 1
        end_positions.append(end_idx + 1)

    # 6) Store these positions in the returned dictionary
    tokenized["start_positions"] = start_positions
    tokenized["end_positions"] = end_positions
    return tokenized

In [None]:
dataset = DatasetDict({
    'train': train_dataset,
    'test': test_dataset,
    'validation': validation_dataset
})

dataset

In [None]:
dataset['train'].column_names

In [None]:
# Tokenize the train and validation splits; test_dataset is left untokenized here.
# prepare_train_features can return more rows than it receives (overflow windows),
# so the original columns are dropped via remove_columns rather than kept alongside.
train_dataset = train_dataset.map(prepare_train_features, batched=True, remove_columns=dataset["train"].column_names)
validation_dataset = validation_dataset.map(prepare_train_features, batched=True, remove_columns=dataset["validation"].column_names)

# Notebook display of the two tokenized datasets
train_dataset, validation_dataset

### Model Training

In [None]:
def show_model_info(model, show_layers=False):
    """Print a human-readable summary of a model.

    The report covers the model name, type, architecture, module structure,
    label mappings, parameter count, quantization/dtype, memory footprint,
    device placement and training mode.

    Args:
        model: Object exposing ``config``, ``get_memory_footprint()``,
            ``num_parameters()``, ``parameters()``, ``named_parameters()``
            and ``named_buffers()``; ``base_model``/``modules()`` or
            ``named_children()`` are used for the structure line.
        show_layers: When True, also list every parameter tensor with more
            than one million elements.

    Returns:
        None. Everything is written to stdout.
    """
    config = model.config
    architecture = None
    model_heads = []
    model_type = "Unknown"
    id2label = None
    label2id = None
    merged_labels = None
    quant_type = "None"

    gbs = model.get_memory_footprint() / 1e9
    param_count = model.num_parameters()
    # Size the parameters would take at 4 bytes each
    fp32_gbs = param_count * 4 / 1e9

    # Model architecture
    if hasattr(config, 'architectures') and config.architectures:
        architecture = config.architectures[0]

    # Model heads
    try:
        if hasattr(model, 'base_model'):
            # Walk from the top-level module down to (and including) the base model
            for module in model.modules():
                model_heads.append(type(module).__name__)

                if module is model.base_model:
                    break
        else:
            # named_children() may be a generator, so materialise it before slicing
            for name, module in list(model.named_children())[:5]:
                model_heads.append(f"{name}({type(module).__name__})")

        # Clean model head list (drop duplicates, keep order, cap at 10)
        model_heads = list(dict.fromkeys(model_heads))[:10]
    except Exception as e:
        # Best effort: structure detection must not abort the whole report
        model_heads = [f"Detection failed: {str(e)[:50]}"]

    # Detect quantization
    if hasattr(config, "quantization_config") and config.quantization_config is not None:
        q_config = config.quantization_config

        if hasattr(q_config, "load_in_4bit") and q_config.load_in_4bit == True:
            quant_type = f"4-bit ({getattr(q_config, 'bnb_4bit_quant_type', 'unknown')})"
        elif hasattr(q_config, "load_in_8bit") and q_config.load_in_8bit == True:
            quant_type = "8-bit"
    else:
        if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
            dtype = config.torch_dtype
            # The dtype may be a plain string or lack `itemsize`; fall back
            # to showing it verbatim instead of raising.
            itemsize = getattr(dtype, "itemsize", None)
            if itemsize is not None:
                quant_type = f"FP{itemsize * 8} ({dtype})"
            else:
                quant_type = str(dtype)

    # Model type detection
    if hasattr(config, 'model_type'):
        model_type = config.model_type

    # Label detection
    if (
        hasattr(config, 'label2id') and config.label2id is not None
        and hasattr(config, 'id2label') and config.id2label is not None
    ):
        id2label = config.id2label
        label2id = config.label2id

        try:
            # Check for label consistency: id2label should be the inverse of label2id
            label2id_swap = {str(v): k for k, v in label2id.items()}
            id2label_str = {str(k): v for k, v in id2label.items()}

            if id2label_str != label2id_swap:
                # Collect every label name seen for each id from both mappings
                merged_labels = {}
                for k, v in id2label.items():
                    key = int(k) if isinstance(k, str) else k
                    merged_labels[key] = [v]

                for k, v in label2id.items():
                    try:
                        v_int = int(v)
                    except (ValueError, TypeError):
                        # Ids that cannot be read as integers are ignored
                        continue

                    if v_int in merged_labels:
                        merged_labels[v_int].append(k)
                        merged_labels[v_int] = list(set(merged_labels[v_int]))
                    else:
                        merged_labels[v_int] = [k]

        except Exception as e:
            print(f"Label validation error: {e}")

    # Basic model info
    print(f"{'='*55}")
    print(f"MODEL: {getattr(config, '_name_or_path', 'Unknown')}")
    print(f"{'='*55}")

    print(f"Model Type: {model_type}")

    if architecture is not None:
        print(f"Architecture: {architecture}")

    if len(model_heads) > 0:
        print(f"Model Structure: {' → '.join(model_heads)}")

    if hasattr(config, "problem_type") and config.problem_type is not None:
        print(f"Problem Type: {config.problem_type}")

    if hasattr(config, "vocab_size"):
        print(f"Vocab Size: {config.vocab_size:,}")

    if id2label is not None:
        print("\nLabel Info:")

        if merged_labels is None:
            print("  ✅ id2label and label2id match")
            print(f"  Label count: {len(id2label)}")

            if len(id2label) <= 10:
                print(f"  Labels: {id2label}")
            else:
                sample_labels = dict(list(id2label.items())[:5])
                print(f"  Labels (sample): {sample_labels}... (+{len(id2label)-5} more)")
        else:
            print("  ⚠️ WARNING: Model id2label and label2id don't match")
            print(f"  Merged labels: {merged_labels}")

    print(f"\nParameters: {param_count:,}")
    print(f"Quantization: {quant_type}")
    print(f"Memory (actual): {gbs:.2f} GB")
    print(f"Memory (FP32 equiv): {fp32_gbs:.2f} GB")

    # Guard the divisor as well, so a model reporting zero parameters cannot
    # trigger a ZeroDivisionError.
    if gbs > 0 and fp32_gbs > 0:
        print(f"Memory savings: {((fp32_gbs - gbs) / fp32_gbs * 100):.1f}%")

    # Device info (a model without parameters has no device to report)
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else "unknown"
    print(f"\nDevice: {device}")

    # Check if all components on same device
    devices = set()
    for name, param in model.named_parameters():
        devices.add(str(param.device))
    for name, buffer in model.named_buffers():
        devices.add(str(buffer.device))

    if len(devices) > 1:
        print(f"⚠️  WARNING: Model spans multiple devices: {devices}")
    else:
        print(f"✅ All components on: {device}")

    # Add training state info
    if hasattr(model, 'training'):
        mode = "Training" if model.training else "Evaluation"
        print(f"\nMode: {mode}")

    # Memory per layer breakdown
    if show_layers:
        print(f"\n{'Layer Breakdown':^55}")
        print(f"{'Layer':<30} {'Parameters':<14} {'Device'}")
        print("-" * 55)
        for name, param in model.named_parameters():
            # Only show layers with >1M params
            if param.numel() > 1000000:
                print(f"{name[:28]:<30} {param.numel():>10,} {str(param.device):>10}")

In [None]:
import wandb
# Load the pretrained checkpoint named by MODEL into the question-answering model class
model = DistilBertForQuestionAnswering.from_pretrained(MODEL)
# NOTE(review): 'imdq' looks like a typo for 'imdb'. This string is reused below as the
# output directory, the W&B project name and the Hub model id — confirm before renaming.
training_dir = 'distilbert-imdq-qa'
batch_size = 8

# Print a summary of the freshly loaded model before training
show_model_info(model)

# Training configuration; the same batch size is used for training and evaluation
training_args = TrainingArguments(
    output_dir=training_dir,
    overwrite_output_dir=True,
    eval_strategy='epoch',
    learning_rate=3e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    weight_decay=0.01,
    save_total_limit=1,
    save_steps=500,
    logging_steps=10
)

In [None]:
# Wire the model, tokenizer and tokenized train/validation splits into a Trainer.
# test_dataset is not passed here; only the validation split is used for evaluation.
trainer = Trainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    data_collator=default_data_collator
)

# Start a Weights & Biases run whose project name matches the training directory
wandb.init(project=training_dir)

In [None]:
trainer.train()

In [None]:
## Log into Hugging Face
from google.colab import userdata
from huggingface_hub import login
login(token=userdata.get('HF_TOKEN'))

In [None]:
trainer.push_to_hub()

In [None]:
from transformers import pipeline

HF_USER = "shayharding"
qa = pipeline("question-answering", model=HF_USER + "/" + training_dir)

In [None]:
# Smoke-test the fine-tuned model on the first movie in the dataset.
data_row = df.iloc[0]
context = build_context_string(data_row)
# Phrase the question exactly like the template used to build the training
# data in build_question_string ("What is the <field> of the movie <title>?"),
# so the check reflects the distribution the model was fine-tuned on.
question = f"What is the genre of the movie {data_row['names']}?"

print(f"Context: {context}\n")
print(f"Question: {question}\n")
print(f"Answer: {qa(question=question, context=context)}")