In [None]:
!pip install adapters



In [41]:
MODEL_NAME = "dccuchile/bert-base-spanish-wwm-uncased"
# MODEL_NAME = 'bert-base-uncased'

In [32]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer, BertConfig
from peft import get_peft_model, LoraConfig, TaskType

class SentenceBERTContrastive(nn.Module):
    def __init__(self):
        super().__init__()
        base_model = BertModel.from_pretrained(MODEL_NAME)

        # Freeze all base model params
        for param in base_model.parameters():
            param.requires_grad = False

        lora_config = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            target_modules=["query", "value"]
        )
        self.bert = get_peft_model(base_model, lora_config)

        self.projection = nn.Linear(self.bert.config.hidden_size, 128)  # still project to embedding space

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]
        return self.projection(cls_embeddings)

In [42]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel, AdamW
from peft import get_peft_model, LoraConfig, TaskType

# Define a model that wraps BERT and applies an adaptor layer.
class TADA(nn.Module):
    def __init__(self, model_name='dccuchile/bert-base-spanish-wwm-uncased'):
        super(TADA, self).__init__()
        base_model = BertModel.from_pretrained(model_name)
        # Freeze all base model params
        for param in base_model.parameters():
            param.requires_grad = False

        lora_config = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            target_modules=["query", "value"]
        )
        self.bert = get_peft_model(base_model, lora_config)
        self.projection = nn.Linear(self.bert.config.hidden_size, 128)

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, **kwargs):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs  # allow future extensions
        )

        cls_embedding = outputs.last_hidden_state[:, 0, :]
        return self.projection(cls_embedding)

In [33]:
class TitleParagraphDataset(Dataset):
    def __init__(self, data, tokenizer, positive_prob=0.5, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.positive_prob = positive_prob

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        title = item["title"]
        paragraph = item["paragraph"]

        if torch.rand(1).item() < self.positive_prob:
            # Positive sample
            label = 1
            text = paragraph
        else:
            # Negative sample
            rand_idx = torch.randint(0, len(self.data), (1,)).item()
            while self.data[rand_idx]["title"] == title:
                rand_idx = torch.randint(0, len(self.data), (1,)).item()
            text = self.data[rand_idx]["paragraph"]
            label = 0

        title_encoding = self.tokenizer(
            title,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        text_encoding = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        return {
            "title_input_ids": title_encoding["input_ids"].squeeze(0),
            "title_attention_mask": title_encoding["attention_mask"].squeeze(0),
            "text_input_ids": text_encoding["input_ids"].squeeze(0),
            "text_attention_mask": text_encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.float)
        }

In [34]:
import json
topic_data = json.load(open("fotech_output.json"))

documents = list(topic_data.values())
total_correct = 0
output_data = []

# Flatten paragraphs
all_paragraphs = []
for doc in documents:
    title = doc["title"]
    for para in doc["text"]:
        all_paragraphs.append({
            "title": title,
            "paragraph": para
        })

In [35]:
from sklearn.model_selection import train_test_split

train_data, temp_data = train_test_split(all_paragraphs, test_size=0.2, random_state=42)
val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)

print(f"Train size: {len(train_data)}")
print(f"Val size: {len(val_data)}")
print(f"Test size: {len(test_data)}")

Train size: 751
Val size: 94
Test size: 94


In [36]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, AdamW, get_scheduler
from sklearn.model_selection import train_test_split
import json
from tqdm import tqdm

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [37]:
train_dataset = TitleParagraphDataset(train_data, tokenizer)
val_dataset = TitleParagraphDataset(val_data, tokenizer)
test_dataset = TitleParagraphDataset(test_data, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [38]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = SentenceBERTContrastive().to(device)

model = TADA().to(device)
# model = TADAWithAlignment().to(device)

weight_path = 'tada_adaptor_epoch_3.pt'
# weight_path = 'tada_alignment_epoch_3.pt'
# load weights
model.load_state_dict(torch.load(weight_path), strict=False)

lora_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query", "value"]
)

model= get_peft_model(model, lora_config)

optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)
criterion = nn.BCEWithLogitsLoss()

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]



In [39]:
num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
    for batch in progress_bar:
        optimizer.zero_grad()

        # Training inputs
        title_input_ids = batch["title_input_ids"].to(device)
        title_attention_mask = batch["title_attention_mask"].to(device)
        text_input_ids = batch["text_input_ids"].to(device)
        text_attention_mask = batch["text_attention_mask"].to(device)
        labels = batch["label"].unsqueeze(1).to(device)

        # Forward
        title_embeds = model(title_input_ids, title_attention_mask)
        text_embeds = model(text_input_ids, text_attention_mask)
        logits = (title_embeds * text_embeds).sum(dim=-1, keepdim=True)

        # Loss & optimize
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Train metrics
        total_loss += loss.item()
        preds = torch.sigmoid(logits) > 0.5
        correct += (preds == labels.bool()).sum().item()
        total += labels.size(0)
        progress_bar.set_postfix(loss=loss.item(), acc=correct / total)

    avg_train_loss = total_loss / len(train_loader)
    train_acc = correct / total

    # ---------------- Validation ----------------
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for batch in val_loader:
            # Validation inputs
            title_input_ids = batch["title_input_ids"].to(device)
            title_attention_mask = batch["title_attention_mask"].to(device)
            text_input_ids = batch["text_input_ids"].to(device)
            text_attention_mask = batch["text_attention_mask"].to(device)
            labels = batch["label"].unsqueeze(1).to(device)

            # Forward
            title_embeds = model(title_input_ids, title_attention_mask)
            text_embeds = model(text_input_ids, text_attention_mask)
            logits = (title_embeds * text_embeds).sum(dim=-1, keepdim=True)

            # Loss
            loss = criterion(logits, labels)
            val_loss += loss.item()
            preds = torch.sigmoid(logits) > 0.5
            val_correct += (preds == labels.bool()).sum().item()
            val_total += labels.size(0)

    avg_val_loss = val_loss / len(val_loader)
    val_acc = val_correct / val_total

    print(f"Epoch {epoch+1}: train_loss = {avg_train_loss:.4f}, train_acc = {train_acc:.4f}, val_loss = {avg_val_loss:.4f}, val_acc = {val_acc:.4f}")

print("Training + validation complete ✅")


Epoch 1/3 [Train]: 100%|██████████| 47/47 [00:47<00:00,  1.01s/it, acc=0.47, loss=0.771]


Epoch 1: train_loss = 2.1551, train_acc = 0.4700, val_loss = 0.8739, val_acc = 0.5319


Epoch 2/3 [Train]: 100%|██████████| 47/47 [00:46<00:00,  1.00it/s, acc=0.506, loss=0.903]


Epoch 2: train_loss = 0.8056, train_acc = 0.5060, val_loss = 0.6732, val_acc = 0.5851


Epoch 3/3 [Train]: 100%|██████████| 47/47 [00:47<00:00,  1.01s/it, acc=0.523, loss=0.751]


Epoch 3: train_loss = 0.7416, train_acc = 0.5233, val_loss = 0.7757, val_acc = 0.5000
Training + validation complete ✅


In [40]:
model.eval()
test_loss = 0.0
test_correct = 0
test_total = 0

with torch.no_grad():
    for batch in test_loader:
        # Test inputs
        title_input_ids = batch["title_input_ids"].to(device)
        title_attention_mask = batch["title_attention_mask"].to(device)
        text_input_ids = batch["text_input_ids"].to(device)
        text_attention_mask = batch["text_attention_mask"].to(device)
        labels = batch["label"].unsqueeze(1).to(device)

        # Forward
        title_embeds = model(title_input_ids, title_attention_mask)
        text_embeds = model(text_input_ids, text_attention_mask)
        logits = (title_embeds * text_embeds).sum(dim=-1, keepdim=True)

        # Loss
        loss = criterion(logits, labels)
        test_loss += loss.item()
        preds = torch.sigmoid(logits) > 0.5
        test_correct += (preds == labels.bool()).sum().item()
        test_total += labels.size(0)

avg_test_loss = test_loss / len(test_loader)
test_acc = test_correct / test_total

print(f"Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

Test Loss: 0.7136, Test Accuracy: 0.5426
