# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader, Dataset

# Load model and tokenizer.
# BioMedLM is published on the Hugging Face Hub under the "stanford-crfm" organization.
MODEL_NAME = "stanford-crfm/BioMedLM"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# TextDataset pads every sample to max_length, which fails if the tokenizer has
# no pad token (common for causal-LM tokenizers). Reuse EOS as padding; padded
# positions are masked out via attention_mask.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Move model to the selected device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Freeze the BioMedLM backbone: only the classifier head is trained.
model.requires_grad_(False)
# Disable dropout in the frozen backbone so the features it produces are not
# randomly perturbed.
model.eval()

# Define a dataset for binary classification
class TextDataset(Dataset):
    """Map-style dataset yielding ``(input_ids, attention_mask, label)`` tensors.

    Each text is tokenized on access, then truncated and padded to
    ``max_length`` so the default DataLoader collate can stack samples.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        """
        Args:
            texts: Indexable sequence of raw strings.
            labels: Indexable sequence of labels aligned with ``texts``.
            tokenizer: Hugging Face tokenizer. It must have a pad token set,
                because every sample is padded to ``max_length``.
            max_length: Token length each sample is truncated/padded to.
        """
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize sample ``idx``.

        Returns:
            ``(input_ids, attention_mask, label)``: two 1-D tensors of length
            ``max_length`` and a tensor built from ``labels[idx]``.
        """
        encoded = self.tokenizer(
            self.texts[idx],
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading batch dim of size 1. Drop only that
        # dim; a bare squeeze() would also collapse a length-1 sequence.
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)
        return input_ids, attention_mask, torch.tensor(self.labels[idx])

# Define a binary classifier
class BinaryClassifier(nn.Module):
    """Linear probe mapping one pooled hidden state per sequence to P(label == 1).

    The backbone used here is a causal (left-to-right) language model, so the
    hidden state at position 0 has attended only to the first token. There is
    no ``[CLS]`` summary token; the state that has seen the whole input is the
    one at the *last* real (non-padding) token, which is what gets pooled.
    """

    def __init__(self, hidden_size):
        """
        Args:
            hidden_size: Width of the hidden states passed to ``forward``.
        """
        super().__init__()
        self.linear = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, hidden_states, attention_mask=None):
        """Return sigmoid probabilities of shape ``(batch, 1)``.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, hidden_size)``.
            attention_mask: Optional ``(batch, seq_len)`` mask, 1 for real tokens
                and 0 for padding. When given, each row's last real token is
                pooled (works with right or left padding). When omitted, the
                final position is used, which is only correct for unpadded input.
        """
        if attention_mask is None:
            pooled = hidden_states[:, -1, :]
        else:
            seq_len = attention_mask.size(1)
            # argmax returns the FIRST maximal index, so on the flipped mask it
            # finds the distance from the end to the last position with mask 1.
            last_idx = seq_len - 1 - attention_mask.flip(dims=(1,)).argmax(dim=1)
            batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
            pooled = hidden_states[batch_idx, last_idx.to(hidden_states.device)]
        return self.sigmoid(self.linear(pooled))

# Combine the frozen BioMedLM backbone with the classifier
class PubMedLMWithClassifier(nn.Module):
    """Frozen language model followed by a trainable classifier head."""

    def __init__(self, language_model, classifier):
        """
        Args:
            language_model: Backbone called with ``output_hidden_states=True``.
            classifier: Head called with ``(last_layer_hidden_states, attention_mask)``.
        """
        super().__init__()
        self.language_model = language_model
        self.classifier = classifier

    def forward(self, input_ids, attention_mask):
        """Return the classifier's probabilities, shape ``(batch, 1)``."""
        # The backbone is frozen, so skip building its autograd graph.
        with torch.no_grad():
            outputs = self.language_model(
                input_ids, attention_mask=attention_mask, output_hidden_states=True
            )
        hidden_states = outputs.hidden_states[-1]  # last layer's hidden states
        # Pass the mask so the head pools the last real token rather than padding.
        return self.classifier(hidden_states, attention_mask)

# Initialize the classifier head (the only trainable part)
classifier = BinaryClassifier(hidden_size=model.config.hidden_size).to(device)

# Initialize the combined model
combined_model = PubMedLMWithClassifier(model, classifier).to(device)

# Define a simple dataset and dataloader
texts = ["Example text 1", "Example text 2"]  # Replace with your actual texts
labels = [0, 1]  # Replace with your actual labels
dataset = TextDataset(texts, labels, tokenizer)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# BCELoss expects probabilities, matching BinaryClassifier's sigmoid output.
criterion = nn.BCELoss()
optimizer = optim.Adam(classifier.parameters(), lr=1e-5)  # Only optimize classifier parameters

# Training loop
num_epochs = 3
for epoch in range(num_epochs):
    combined_model.train()
    # train() recurses into the frozen backbone; put it back in eval mode so
    # dropout does not perturb the features fed to the classifier.
    combined_model.language_model.eval()
    # `batch_labels` (not `labels`) so the dataset's label list is not shadowed.
    for input_ids, attention_mask, batch_labels in dataloader:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        batch_labels = batch_labels.to(device).float()

        # Forward pass
        probs = combined_model(input_ids, attention_mask)
        # squeeze(-1), not squeeze(): (batch, 1) -> (batch,) even when the last
        # batch has a single sample, so the shape always matches batch_labels.
        loss = criterion(probs.squeeze(-1), batch_labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

print("Training complete.")
