# DISTILBERT

# Pipeline

In [None]:
!gdown 1GnUYDYrpc3H3EVpem8sCBOZ1ZvcjI4oG

Downloading...
From: https://drive.google.com/uc?id=1GnUYDYrpc3H3EVpem8sCBOZ1ZvcjI4oG
To: /content/insert_chars_1.json
100% 55.8M/55.8M [00:00<00:00, 70.2MB/s]


In [37]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import pipeline
from tqdm import tqdm
import time

# Select the inference device in the form the transformers pipeline expects:
# a CUDA device index (0 = first GPU) when one is available, otherwise -1 (CPU).
device = 0 if torch.cuda.is_available() else -1
# Samples per batch; also handed to the pipeline as its internal batch size.
# NOTE(review): 128 is a tuning choice, not a measured optimum -- adjust to the
# available GPU memory.
BATCH_SIZE = 128

# Initialize the question answering pipeline (PyTorch backend) on that device
model_name = "distilbert-base-uncased-distilled-squad"
question_answerer = pipeline(
    "question-answering",
    model=model_name,
    device=device,
    batch_size=BATCH_SIZE,
    framework="pt",
)

# Load the SQuAD-style dataset JSON file. The encoding is given explicitly so
# decoding does not depend on the platform's locale default.
with open('insert_chars_1.json', 'r', encoding='utf-8') as file:
    squad_data = json.load(file)

class SQuADDataset(Dataset):
    """Flat, indexable view over every question found in the loaded data.

    Each element of ``squad_data`` is read as ``item['data']['paragraphs']``.
    A paragraph supplies a ``context``, an optional ``new_context`` (the plain
    ``context`` is used when that key is absent) and a list of ``qas``.
    One sample dict is stored per question, in input order.
    """

    def __init__(self, squad_data):
        """Flatten ``squad_data`` into ``self.samples`` (a list of dicts)."""
        self.samples = []
        for entry in squad_data:
            for para in entry['data']['paragraphs']:
                original_text = para['context']
                altered_text = para.get('new_context', original_text)
                # One record per question; the paragraph texts are shared.
                self.samples.extend(
                    {
                        'id': qa['id'],
                        'question': qa['question'],
                        'context': original_text,
                        'new_context': altered_text,
                        'orig_answer': qa['answers'],
                    }
                    for qa in para['qas']
                )

    def __len__(self):
        """Return the number of flattened question samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample dict stored at position ``idx``."""
        return self.samples[idx]

def collate_fn(batch):
    """Turn a list of sample dicts into five parallel lists.

    Returns a tuple ``(questions, contexts, new_contexts, ids, orig_answers)``
    where each element is a list holding that field for every sample in
    ``batch``, in batch order.
    """
    def column(key):
        # Pull one field out of every sample, preserving batch order.
        return [sample[key] for sample in batch]

    return (
        column('question'),
        column('context'),
        column('new_context'),
        column('id'),
        column('orig_answer'),
    )

# Create dataset and dataloader.
# The samples are plain Python dicts already held in memory and the collate
# function returns lists of those values, so worker processes would only add
# start-up and inter-process pickling overhead, and there are no tensors for
# pin_memory to pin. Batches are therefore built in the main process.
dataset = SQuADDataset(squad_data)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=0,
    pin_memory=False,
)

# Function to perform batch question answering
def batch_question_answer(question_answerer, questions, contexts):
    """Answer each question against its paired context in one pipeline call.

    Args:
        question_answerer: Callable taking a list of
            ``{'question': ..., 'context': ...}`` dicts (here, the
            question-answering pipeline).
        questions: Sequence of question strings.
        contexts: Sequence of context strings, paired with ``questions`` by
            position (pairs are formed with ``zip``).

    Returns:
        A list with one result per question/context pair, in input order.
    """
    batch = [{'question': q, 'context': c} for q, c in zip(questions, contexts)]
    if not batch:
        # Nothing to answer; avoid invoking the model on an empty input.
        return []

    outputs = question_answerer(batch)
    # The question-answering pipeline returns a bare dict rather than a
    # one-element list when given a single example; normalise so callers can
    # always index the result by sample position.
    if isinstance(outputs, dict):
        return [outputs]
    return list(outputs)

def _answers_for(questions, contexts_batch):
    """Return the pipeline results for one batch as a list, one per question."""
    output = batch_question_answer(question_answerer, questions, contexts_batch)
    # A single-example call can come back as a bare dict; wrap it so the
    # result is always indexable by sample position.
    if isinstance(output, dict):
        return [output]
    return list(output)


def _save_results(records, path='squad_results.json'):
    """Write the collected result records to ``path`` as indented JSON."""
    with open(path, 'w', encoding='utf-8') as file:
        json.dump(records, file, indent=2)


results = []

start_time = time.time()
completed = False

# Processing loop
try:
    for questions, contexts, new_contexts, ids, orig_answers in tqdm(dataloader):
        batch_start_time = time.time()

        # Answer every question against the original and the modified context.
        # torch.jit.fork only runs asynchronously inside TorchScript; in plain
        # Python it executes the callable immediately, so the two pipeline
        # calls are simply made one after the other.
        context_results = _answers_for(questions, contexts)
        new_context_results = _answers_for(questions, new_contexts)

        for idx, qa_id in enumerate(ids):
            original_result = context_results[idx]
            modified_result = new_context_results[idx]
            # Keep only samples for which both runs produced an answer.
            if 'answer' in original_result and 'answer' in modified_result:
                results.append({
                    'id': qa_id,
                    'question': questions[idx],
                    'new_context': new_contexts[idx],
                    'orig_answer': orig_answers[idx],
                    'context_answer': original_result['answer'],
                    'new_context_answer': modified_result['answer'],
                    # score/start/end describe the answer on the modified context
                    'score': modified_result['score'],
                    'start': modified_result['start'],
                    'end': modified_result['end']
                })

        batch_end_time = time.time()
        print(f"Batch processing time: {batch_end_time - batch_start_time:.2f} seconds")

    end_time = time.time()
    print(f"Total processing time: {end_time - start_time:.2f} seconds")
    completed = True
finally:
    # Save results to a new JSON file. This also runs when the loop is
    # interrupted or fails part-way, so a long run does not lose the batches
    # already processed; an aborted run that produced nothing leaves any
    # existing output file untouched.
    if completed or results:
        _save_results(results)


  0%|          | 1/1019 [00:08<2:25:19,  8.56s/it]

Batch processing time: 7.28 seconds


  0%|          | 2/1019 [00:16<2:18:18,  8.16s/it]

Batch processing time: 7.87 seconds


  0%|          | 3/1019 [00:23<2:10:39,  7.72s/it]

Batch processing time: 7.18 seconds


  0%|          | 4/1019 [00:31<2:08:45,  7.61s/it]

Batch processing time: 7.44 seconds


  0%|          | 4/1019 [00:37<2:39:40,  9.44s/it]


KeyboardInterrupt: 