In [1]:
import torch
from torch.utils.data import DataLoader, Dataset, Sampler
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
import tqdm
import faiss
import pandas as pd
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

cpu


In [2]:
class ClaimsDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        claim = self.data.iloc[idx]['summary_sentence']
        passage = self.data.iloc[idx]['text_chunk']
        group_id = self.data.iloc[idx]['book_num']
        
        claim_inputs = self.tokenizer(claim, return_tensors='pt', padding='max_length', truncation=True, max_length=512)
        passage_inputs = self.tokenizer(passage, return_tensors='pt', padding='max_length', truncation=True, max_length=512)
        return claim_inputs, passage_inputs, group_id
    
    import torch
    from torch.utils.data import Sampler
    import numpy as np

class GroupedSampler(Sampler):
    def __init__(self, data_source, group_ids, batch_size):
        self.data_source = data_source
        self.group_ids = group_ids
        self.batch_size = batch_size
        self.grouped_indices = self.group_indices_by_id()

    def group_indices_by_id(self):
        grouped_indices = {}
        for idx, group_id in enumerate(self.group_ids):
            if group_id not in grouped_indices:
                grouped_indices[group_id] = []
            grouped_indices[group_id].append(idx)
        return grouped_indices

    def __iter__(self):
        batches = []
        for group in self.grouped_indices.values():
            print(f"Original group indices: {group}")  # Debug print
            np.random.shuffle(group)  # Shuffle within the group
            print(f"Shuffled group indices: {group}")  # Debug print
            for i in range(0, len(group), self.batch_size):
                batch = group[i:i+self.batch_size]
                if len(batch) == self.batch_size:
                    batches.append(batch)
                    print(f"Created batch: {batch}")  # Debug print
        np.random.shuffle(batches)  # Shuffle the batches
        print(f"Final shuffled batches: {batches}")  # Debug print
        return iter(batches)


    def __len__(self):
        return len(self.data_source) // self.batch_size

In [3]:
def fine_tune(model, train_loader, val_loader, epochs=3, lr=5e-5, batch_size=32, device='cuda', temperature=1.0, label_smoothing=0.1, warmup_steps=2000):
    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    model.to(device)

    # Initialize logging
    log_interval = 20
    training_loss_log = []
    validation_loss_log = []

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        step = 0

        for claim_inputs, passage_inputs, _ in tqdm.tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
            optimizer.zero_grad()
            claim_inputs = {k: v.squeeze(1).to(device) for k, v in claim_inputs.items()}
            passage_inputs = {k: v.squeeze(1).to(device) for k, v in passage_inputs.items()}
            claim_outputs = model(**claim_inputs).last_hidden_state.mean(dim=1)
            passage_outputs = model(**passage_inputs).last_hidden_state.mean(dim=1)
            scores = torch.einsum("id, jd->ij", claim_outputs / temperature, passage_outputs)
            bsz = len(claim_inputs['input_ids'])
            labels = torch.arange(0, bsz, dtype=torch.long, device=device)
            loss = torch.nn.functional.cross_entropy(scores, labels, label_smoothing=label_smoothing)
            loss.backward()
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            
            # Log training loss every 20 steps
            if step % log_interval == 0 and step > 0:
                training_loss_log.append((epoch, step, loss.item()))
                print(f"Epoch {epoch+1}, Step {step}, Training Loss: {loss.item()}")
            step += 1

        print(f"Epoch {epoch+1}, Average Training Loss: {total_loss/len(train_loader)}")

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for claim_inputs, passage_inputs, _ in tqdm.tqdm(val_loader, desc=f"Validation Epoch {epoch+1}"):
                claim_inputs = {k: v.squeeze(1).to(device) for k, v in claim_inputs.items()}
                passage_inputs = {k: v.squeeze(1).to(device) for k, v in passage_inputs.items()}
                claim_outputs = model(**claim_inputs).last_hidden_state.mean(dim=1)
                passage_outputs = model(**passage_inputs).last_hidden_state.mean(dim=1)
                scores = torch.einsum("id, jd->ij", claim_outputs / temperature, passage_outputs)
                bsz = len(claim_inputs['input_ids'])
                labels = torch.arange(0, bsz, dtype=torch.long, device=device)
                loss = torch.nn.functional.cross_entropy(scores, labels, label_smoothing=label_smoothing)
                val_loss += loss.item()
                validation_loss_log.append((epoch, step, loss.item()))
        
        print(f"Epoch {epoch+1}, Validation Loss: {val_loss/len(val_loader)}")
    
    model.save_pretrained(f'fine_tuned_model_batch_size={batch_size}')

    # Save loss logs to files
    training_loss_filename = f'training_loss_log_batch_size_{batch_size}.txt'
    validation_loss_filename = f'validation_loss_log_batch_size_{batch_size}.txt'

    with open(training_loss_filename, 'w') as f:
        for epoch, step, loss in training_loss_log:
            f.write(f"Epoch: {epoch+1}, Step: {step}, Training Loss: {loss}\n")

    with open(validation_loss_filename, 'w') as f:
        for epoch, step, loss in validation_loss_log:
            f.write(f"Epoch: {epoch+1}, Step: {step}, Validation Loss: {loss}\n")



In [4]:
from transformers import AdamW, get_linear_schedule_with_warmup
import torch
import tqdm

def main(batch_size):
    data = pd.read_csv('Data/mapped_summaries_l3.csv')
    tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')
    model = AutoModel.from_pretrained('facebook/contriever')

    # Calculate the indices for the splits
    train_size = int(0.8 * len(data))
    val_size = int(0.1 * len(data))

    # Split the data linearly
    train_data = data.iloc[:train_size]
    val_data = data.iloc[train_size:train_size + val_size]
    test_data = data.iloc[train_size + val_size:]

    # Save the datasets into CSV files
    train_data.to_csv('train_data.csv', index=False)
    val_data.to_csv('val_data.csv', index=False)
    test_data.to_csv('test_data.csv', index=False)

    train_dataset = ClaimsDataset(data=train_data, tokenizer=tokenizer)
    val_dataset = ClaimsDataset(data=val_data, tokenizer=tokenizer)

    train_group_ids = train_data['book_num'].tolist()
    train_sampler = GroupedSampler(data_source=train_dataset, group_ids=train_group_ids, batch_size=batch_size)
    train_loader = DataLoader(train_dataset, batch_sampler=train_sampler)
    
    val_group_ids = val_data['book_num'].tolist()
    val_sampler = GroupedSampler(data_source=val_dataset, group_ids=val_group_ids, batch_size=batch_size)
    val_loader = DataLoader(val_dataset, batch_sampler=val_sampler)

    fine_tune(model, train_loader, val_loader, epochs=10, lr=5e-5, batch_size=batch_size, device=device)
    tokenizer.save_pretrained(f'fine_tuned_model_batch_size={batch_size}')


main(32)

Training Epoch 1:   0%|          | 0/5447 [00:00<?, ?it/s]

Original group indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 

Training Epoch 1:   0%|          | 0/5447 [00:03<?, ?it/s]


KeyboardInterrupt: 