In [None]:
!pip install transformers

# CLIP

Using a pretrained CLIP model for zero-shot image classification.

In [None]:
import io

import requests
import torch
from matplotlib import pyplot as plt
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Sample images for the demo: one cat and one dog
cat_url = "https://ds4440.baulab.info/data/cat.png"
dog_url = "https://ds4440.baulab.info/data/dog.png"

# The processor and the model are both loaded from the same pretrained checkpoint
clip_checkpoint = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(clip_checkpoint)
model = CLIPModel.from_pretrained(clip_checkpoint)

# Download and preprocess the images
def preprocess_image(url):
    """Download an image, display it, and turn it into CLIP model inputs.

    Args:
        url: HTTP(S) address of the image to fetch.

    Returns:
        The result of ``processor(images=image, return_tensors="pt")`` for the
        downloaded image; callers read ``pixel_values`` from it.

    Raises:
        requests.HTTPError: if the server answers with an error status code.
        requests.Timeout: if the request exceeds the 30 second timeout.
    """
    # A timeout keeps an unresponsive server from hanging the notebook, and
    # raise_for_status reports a 404/500 here instead of letting PIL fail
    # later while trying to decode an error page.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # Decode from an in-memory copy of the body so the whole response is
    # consumed and no streamed connection is left open.
    image = Image.open(io.BytesIO(response.content))
    image = image.resize((224, 224))
    # display the image
    plt.imshow(image)
    plt.axis("off")
    plt.show()
    inputs = processor(images=image, return_tensors="pt")
    return inputs

# Prepare the text inputs, one prompt per class
cat_text = "a photo of a cat"
dog_text = "a photo of a dog"
text_inputs = processor(text=[cat_text, dog_text], padding=True, return_tensors="pt")

# Get the text features (inference only, so no autograd graph is needed)
with torch.no_grad():
    text_features = model.get_text_features(
        input_ids=text_inputs["input_ids"],
        attention_mask=text_inputs["attention_mask"]
    )
# L2-normalize so the dot product below is a cosine similarity. Without this,
# a prompt whose embedding simply has a larger norm would score higher
# regardless of how well it matches the image.
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# For each image, get the image features and check similarity with the text
for url in [cat_url, dog_url]:
    inputs = preprocess_image(url)
    with torch.no_grad():
        image_features = model.get_image_features(pixel_values=inputs["pixel_values"])
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    # Compute the similarity scores once: one image row against both prompts,
    # in the order [cat, dog]
    similarities = (image_features @ text_features.T).squeeze(0)
    cat_similarity = similarities[0].item()
    dog_similarity = similarities[1].item()

    # Print the similarity scores
    print(f"Cat Similarity: {cat_similarity}")
    print(f"Dog Similarity: {dog_similarity}")

    # Classify the image based on the higher similarity score
    if cat_similarity > dog_similarity:
        print("The image is classified as a cat.")
    else:
        print("The image is classified as a dog.")

# BERT

A brief demonstration of fine-tuning a classifier on top of a frozen BERT sentence representation for a sentiment classification problem (SST-2).

In [None]:
# Install required libraries
!pip install torch transformers datasets

In [None]:
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch
from tqdm.notebook import tqdm

# Seed torch's global RNG before the model is built so that the randomly
# initialized classification head (and any later torch-driven shuffling in
# this run) is repeatable after a kernel restart.
torch.manual_seed(0)

# Load the SST-2 dataset
dataset = load_dataset('glue', 'sst2')

# Load the BERT tokenizer and model
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Freeze the BERT model parameters so no gradients are computed for the encoder
for param in model.bert.parameters():
    param.requires_grad = False

# Preprocess the dataset
def preprocess_function(examples):
    """Tokenize the ``sentence`` field of a batch of examples.

    Truncation is enabled; no padding argument is passed to the tokenizer.
    """
    sentences = examples['sentence']
    return tokenizer(sentences, truncation=True)

# batched=True passes groups of examples to the function in each call
encoded_dataset = dataset.map(preprocess_function, batched=True)

# Collate function to dynamically pad sequences in a batch
def collate_fn(batch, pad_token_id=0):
    """Pad a list of tokenized examples to the longest sequence in the batch.

    Args:
        batch: list of examples, each providing ``input_ids``,
            ``attention_mask`` and ``label``.
        pad_token_id: id written into the padded positions of ``input_ids``.
            The default of 0 keeps the previous behavior (0 is the ``[PAD]``
            id of the bert-base-uncased vocabulary); pass
            ``tokenizer.pad_token_id`` when using a tokenizer that pads with
            a different id.

    Returns:
        Dict with ``input_ids`` and ``attention_mask`` tensors of shape
        (batch, longest sequence) and a 1-D ``labels`` tensor.
    """
    input_ids = [torch.tensor(item['input_ids']) for item in batch]
    attention_mask = [torch.tensor(item['attention_mask']) for item in batch]
    labels = torch.tensor([item['label'] for item in batch])
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
    # The mask is always padded with 0 so padded positions are ignored
    attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}

# Create DataLoaders with collate function
# Only the training loader shuffles; both pad per batch through collate_fn.
train_dataloader = DataLoader(
    encoded_dataset['train'],
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
)
eval_dataloader = DataLoader(
    encoded_dataset['validation'],
    batch_size=64,
    collate_fn=collate_fn,
)

# Set up optimizer and learning rate
# Only the classifier head's parameters are handed to the optimizer.
optimizer = AdamW(
    model.classifier.parameters(),
    lr=2e-5,
)

# Fine-tune the classifier part of the model
# The classifier head goes into training mode while the BERT encoder is
# put into eval mode.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 1
model.bert.eval()
model.classifier.train()
model.to(device)

# One progress-bar tick per optimizer step across all epochs
total_steps = num_epochs * len(train_dataloader)
progress_bar = tqdm(range(total_steps), desc="Training")

for epoch in range(num_epochs):
    epoch_loss = 0.0
    epoch_length = 0
    for batch in train_dataloader:
        # Move every tensor of the batch (ids, mask, labels) to the device
        batch = {name: tensor.to(device) for name, tensor in batch.items()}

        # Labels are passed along, so the model output carries the loss
        loss = model(**batch).loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Track the running mean loss of this epoch for the progress bar
        epoch_loss += loss.item()
        epoch_length += 1
        progress_bar.update(1)
        progress_bar.set_postfix({"Epoch": epoch+1, "Loss": epoch_loss / epoch_length})

    print(f"Epoch {epoch+1} Loss: {epoch_loss / epoch_length:.4f}")

progress_bar.close()

# Evaluate the model
model.eval()
total = 0
correct = 0

# Accuracy only needs forward passes, so gradient tracking is switched off
with torch.no_grad():
    for batch in tqdm(eval_dataloader, desc="Evaluation"):
        labels = batch['labels'].to(device)
        logits = model(
            batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
        ).logits

        # The predicted class is the index of the largest logit in each row
        predictions = logits.argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
print(f"Evaluation Accuracy: {accuracy:.4f}")



In [None]:
# Evaluate student-defined examples
student_examples = [
    "This movie was predictably amazing.",
    "The movie was amazingly predictable.",
    "But I didn't enjoy the book at all.",
    "All didn't enjoy the book, but I did.",
]

# Tokenize all sentences as one padded batch and move it to `device`
student_inputs = tokenizer(
    student_examples,
    padding=True,
    truncation=True,
    return_tensors="pt",
).to(device)

with torch.no_grad():
    logits = model(**student_inputs).logits
    predictions = logits.argmax(dim=1)

# Print the predictions (class index 1 is reported as positive)
for example, label_id in zip(student_examples, predictions.tolist()):
    sentiment = "Positive" if label_id == 1 else "Negative"
    print(f"Example: {example}")
    print(f"Predicted Sentiment: {sentiment}")
    print()