In [1]:
!pip install adapters

Collecting adapters
  Downloading adapters-1.1.0-py3-none-any.whl.metadata (16 kB)
Collecting transformers~=4.47.1 (from adapters)
  Downloading transformers-4.47.1-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.1/44.1 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
Downloading adapters-1.1.0-py3-none-any.whl (293 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m293.4/293.4 kB[0m [31m24.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading transformers-4.47.1-py3-none-any.whl (10.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.1/10.1 MB[0m [31m128.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: transformers, adapters
  Attempting uninstall: transformers
    Found existing installation: transformers 4.48.3
    Uninstalling transformers-4.48.3:
      Successfully uninstalled transformers-4.48.3
Successfully installed adapters-1.1.0 transformers-4.47.1


## Alignment Code from https://github.com/Helw150/tada/blob/main/models.py

In [13]:
from torch.autograd import Function

class GradientReversal(Function):
    @staticmethod
    def forward(ctx, i):
        return i

    @staticmethod
    def backward(ctx, grad_output):
        return -1 * grad_output

In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel
from peft import get_peft_model, LoraConfig, TaskType

# Optional: Dummy Gradient Reversal Layer implementation
class GradientReversal(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

# Main model with LoRA adapters + alignment loss + critic
class TADAWithAlignment(nn.Module):
    def __init__(self, model_name='dccuchile/bert-base-spanish-wwm-uncased'):
        super().__init__()

        # Base model with LoRA adapters
        base_model = BertModel.from_pretrained(model_name)
        for param in base_model.parameters():
            param.requires_grad = False  # Freeze base model

        lora_config = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            target_modules=["query", "value"]
        )
        self.bert = get_peft_model(base_model, lora_config)

        # Critic network
        hidden_size = self.bert.config.hidden_size
        self.critic_transform = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=12)
        self.critic_score = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, 1)
        )

    @torch.no_grad()
    def produce_original_embeddings(self, input_ids, attention_mask):
        self.eval()
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden_states = outputs.hidden_states[-1]  # Last hidden layer
        hidden_states = hidden_states * attention_mask.unsqueeze(-1)
        self.train()
        return hidden_states

    def critic(self, embedding):
        mask = embedding.sum(-1) != 0  # padding mask
        cls_token = self.critic_transform(
            embedding.permute(1, 0, 2), src_key_padding_mask=mask
        )[0, :, :]  # Take first token (CLS)
        scores = self.critic_score(cls_token)
        return scores.mean()

    def forward(self, input_ids, attention_mask, original_embedding=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden_states = outputs.hidden_states[-1] * attention_mask.unsqueeze(-1)

        if original_embedding is not None:
            # Adversarial alignment mode
            hidden_states_reversed = GradientReversal.apply(hidden_states)
            alignment_loss = (
                (original_embedding[:, 0, :] - hidden_states_reversed[:, 0, :])
                .square()
                .sum(1)
                .mean()
            )
            critic_loss = self.critic(hidden_states_reversed) - self.critic(original_embedding)
            total_alignment_loss = critic_loss - alignment_loss
            return total_alignment_loss
        else:
            # Inference mode (just CLS token for downstream task)
            return hidden_states[:, 0, :]  # Return [CLS] token

In [22]:
# Instantiate the tokenizer and model.
tokenizer = BertTokenizer.from_pretrained('dccuchile/bert-base-spanish-wwm-uncased')
model = TADAWithAlignment()

# Define optimizer and loss function.
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5)

Some weights of BertModel were not initialized from the model checkpoint at dccuchile/bert-base-spanish-wwm-uncased and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
from torch.utils.data import Dataset, DataLoader
# Custom Dataset to hold text pairs.
class TextPairDataset(Dataset):
    def __init__(self, data_pairs):
        """
        data_pairs: List of tuples (original_text, transformed_text)
        """
        self.data = data_pairs

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        original_text, transformed_text = self.data[idx]
        return {'text1': original_text, 'text2': transformed_text}

# Collate function to batch and tokenize samples.
def collate_fn(batch, tokenizer, max_length=512):
    texts1 = [item['text1'] for item in batch]
    texts2 = [item['text2'] for item in batch]
    encoding1 = tokenizer(texts1, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
    encoding2 = tokenizer(texts2, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
    return encoding1, encoding2

In [24]:
import json
import random
# Load JSON dataset.
with open('top_n_chilean_examples.json', 'r', encoding='utf-8') as f:
    data_json = json.load(f)

# Extract paired texts: use the "original_text" (or the key as a fallback) and "transformed_text"
data_pairs = []
for key, value in data_json.items():
    original_text = value.get('original_text', key)
    transformed_text = value.get('transformed_text', None)
    if transformed_text is not None:
        data_pairs.append((original_text, transformed_text))

# Shuffle and split into train and eval sets (e.g., 80/20 split)
random.shuffle(data_pairs)
split_idx = int(len(data_pairs) * 0.8)
train_pairs = data_pairs[:split_idx]
eval_pairs = data_pairs[split_idx:]

# Create dataset objects.
train_dataset = TextPairDataset(train_pairs)
eval_dataset = TextPairDataset(eval_pairs)

In [25]:
# Create DataLoaders.
batch_size = 8
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda batch: collate_fn(batch, tokenizer)
)
eval_loader = DataLoader(
    eval_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: collate_fn(batch, tokenizer)
)

In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

TADAWithAlignment(
  (bert): PeftModelForFeatureExtraction(
    (base_model): LoraModel(
      (model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(31002, 768, padding_idx=1)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSdpaSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=768, out_features=768, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_feat

In [27]:
lambda_alignment = 0.1
num_epochs = 3
criterion = nn.MSELoss()

for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0.0

    for encoding1, encoding2 in train_loader:
        input_ids1 = encoding1['input_ids'].to(device)
        attention_mask1 = encoding1['attention_mask'].to(device)
        input_ids2 = encoding2['input_ids'].to(device)
        attention_mask2 = encoding2['attention_mask'].to(device)

        optimizer.zero_grad()

        # Step 1: Original frozen embeddings
        with torch.no_grad():
            original_embedding1 = model.produce_original_embeddings(input_ids1, attention_mask1)
            original_embedding2 = model.produce_original_embeddings(input_ids2, attention_mask2)

        # Step 2: Alignment loss (adversarial)
        alignment_loss1 = model(input_ids1, attention_mask1, original_embedding=original_embedding1)
        alignment_loss2 = model(input_ids2, attention_mask2, original_embedding=original_embedding2)

        # Step 3: Contrastive loss
        cls1 = model(input_ids1, attention_mask1)
        cls2 = model(input_ids2, attention_mask2)
        contrastive_loss = criterion(cls1, cls2)

        # Step 4: Total loss
        total_loss = contrastive_loss + lambda_alignment * (alignment_loss1 + alignment_loss2)
        total_loss.backward()
        optimizer.step()

        total_train_loss += total_loss.item()

    avg_train_loss = total_train_loss / len(train_loader)
    print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}")

    model.eval()
    total_eval_loss = 0.0
    with torch.no_grad():
        for encoding1, encoding2 in eval_loader:
            input_ids1 = encoding1['input_ids'].to(device)
            attention_mask1 = encoding1['attention_mask'].to(device)
            input_ids2 = encoding2['input_ids'].to(device)
            attention_mask2 = encoding2['attention_mask'].to(device)
            cls1 = model(input_ids1, attention_mask1)
            cls2 = model(input_ids2, attention_mask2)
            loss = criterion(cls1, cls2)
            total_eval_loss += loss.item()
    avg_eval_loss = total_eval_loss / len(eval_loader)
    print(f"Eval Loss: {avg_eval_loss:.4f}")

    torch.save(model.state_dict(), f'tada_alignment_epoch_{epoch+1}.pt')
    print(f"Model saved to tada_alignment_epoch_{epoch+1}.pt")

Epoch 1 - Train Loss: -9.9349
Eval Loss: 0.0010
Model saved to tada_alignment_epoch_1.pt
Epoch 2 - Train Loss: -45.4152
Eval Loss: 0.0003
Model saved to tada_alignment_epoch_2.pt
Epoch 3 - Train Loss: -132.4428
Eval Loss: 0.0002
Model saved to tada_alignment_epoch_3.pt


In [30]:
text = data_pairs[0][0]

In [31]:
model.eval()
with torch.no_grad():
    input_ids = tokenizer(text, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
    attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

    # Inference mode (no alignment loss)
    cls_embedding = model(input_ids, attention_mask)

In [34]:
cls_embedding.shape

torch.Size([1, 768])