# Classification with CLIP Text portion (https://github.com/openai/CLIP)

This notebook uses only CLIP's text encoder, a masked self-attention Transformer, to classify posts from their titles.

In [1]:
# imports
import torch
import numpy as np

# Fix the seed of torch's global RNG so weight initialisation and DataLoader
# shuffling are repeatable. Only torch is seeded here; numpy's RNG is not.
SEED = 42
torch.manual_seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7fe0851acb90>

In [2]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device  # bare expression: displayed as the cell output

'cuda:0'

In [3]:
import clip
# List the model names accepted by clip.load (displayed as the cell output).
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [4]:
# Only the image preprocessing transform is needed at this point (it is passed
# to the datasets below); the model returned alongside it is discarded.
# CLIPClassifier loads its own copy of the CLIP model later.
model, preprocess = clip.load("ViT-B/32", jit=True, device=device)
del model

# Load Dataset

In [5]:
import pandas as pd
import os

from torch.utils.data import Dataset, DataLoader
from PIL import Image

In [6]:
class FakedditDataset(Dataset):
    """Subset of the Fakeddit fake news dataset, served as (image, text, label) samples.

    The backing table is read positionally: column 0 holds the text, column 1
    the image id (the file is expected at ``<root_dir>/<id>.jpg``) and column 2
    the integer class index.
    """

    def __init__(self, dataset, root_dir, image_preprocess=None, num_classes=6):
        """
        Args:
            dataset (str | os.PathLike | pandas.DataFrame): Path to the csv file,
                or an already loaded DataFrame (used as-is, not copied).
            root_dir (string): Directory with all the images.
            image_preprocess (callable, optional): Transform applied to the
                RGB-converted image. When omitted, the image is returned as
                loaded, in its original mode.
            num_classes (int): Length of the one-hot label vector. Defaults to 6,
                the value that was previously hard-coded.
        """
        # isinstance also accepts str subclasses and pathlib paths.
        if isinstance(dataset, (str, os.PathLike)):
            self.dataset = pd.read_csv(dataset)
        else:
            self.dataset = dataset
        self.root_dir = root_dir
        self.image_preprocess = image_preprocess
        self.num_classes = num_classes

    def __len__(self):
        """Return the number of rows (samples) in the backing table."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, text, label)`` for row ``idx``; ``label`` is a one-hot float tensor."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        text = self.dataset.iloc[idx, 0]
        img_name = os.path.join(self.root_dir, f"{self.dataset.iloc[idx, 1]}.jpg")

        # Image.open is lazy and keeps the file handle open; the context manager
        # closes it once the pixel data has been materialised, so DataLoader
        # workers do not accumulate open files.
        with Image.open(img_name) as raw_image:
            if self.image_preprocess:
                image = self.image_preprocess(raw_image.convert("RGB"))
            else:
                # copy() loads the pixels into an image that is independent of the file.
                image = raw_image.copy()

        label = torch.zeros(self.num_classes)
        label[self.dataset.iloc[idx, 2]] = 1

        return image, text, label

In [7]:
# Number of samples per mini-batch for both loaders.
batch_size = 32

# Both splits read their images from 'data' and share the CLIP image transform.
trainset = FakedditDataset('train_clean.csv', 'data', image_preprocess=preprocess)
testset = FakedditDataset('test_clean.csv', 'data', image_preprocess=preprocess)

# Shuffle only the training data. Evaluation does not benefit from shuffling,
# and a fixed order keeps the evaluation pass deterministic.
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

In [8]:
# Pull a single batch as a sanity check of what the loader yields.
dataiter = iter(trainloader)
images, texts, labels = next(dataiter)
# Prints the first image's tensor shape, the length of the first text entry
# (presumably its character count, assuming texts[0] is a str) and the first
# label's shape.
print(images[0].shape, len(texts[0]), labels[0].shape)

torch.Size([3, 224, 224]) 36 torch.Size([6])


# Model Definition

In [9]:
import torch
import torch.nn as nn

In [10]:
class CLIPClassifier(nn.Module):
    """Text-only classifier: CLIP's text encoder followed by a small MLP head.

    The head maps 512 -> 512 -> 128 -> 6 and returns raw logits (no final
    activation). Image features are not used.
    """

    def __init__(self, device='cpu') -> None:
        """Load the "ViT-B/32" CLIP model and build the head on ``device``."""
        super().__init__()
        self.device = device

        # clip.load returns a pair; only the model is kept, the second element
        # is not needed here.
        clip_model, _ = clip.load("ViT-B/32", jit=True, device=device)
        self.clip_layer = clip_model

        # The head consumes the text features directly; fc1 expects 512 inputs.
        self.fc1 = nn.Linear(512, 512, device=device)
        self.fc2 = nn.Linear(512, 128, device=device)
        self.fc3 = nn.Linear(128, 6, device=device)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, text):
        """Encode tokenized ``text`` with CLIP and return the 6 class logits per sample."""
        hidden = self.clip_layer.encode_text(text).float()

        # Two ReLU-activated hidden layers ...
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))

        # ... then dropout right before the output projection.
        return self.fc3(self.dropout(hidden))

# classifier = CLIPClassifier(device=device)
# classifier.device

# Training

In [11]:
def binary_accuracy(preds, y):
    """Return the fraction of rows whose arg-max class matches between ``preds`` and ``y``.

    Both arguments hold one row per sample with one column per class
    (``y`` one-hot encoded, ``preds`` scores or logits). The result is a
    Python float in [0, 1].
    """
    predicted = preds.argmax(dim=1)
    target = y.argmax(dim=1)

    n_hits = (predicted == target).sum().item()
    return n_hits / len(target)

In [12]:
import math
import copy

# Optimiser hyper-parameters and training schedule.
learning_rate = 1e-4  # initial (and, with no scheduler here, constant) SGD learning rate
n_epochs = 50 # how many epochs to run
momentum = 0.9  # SGD momentum
patience = 10 # consecutive epochs without a new best validation loss before early stopping

# Checkpoints are written to <MODEL_LOCATION>/<MODEL_VERSION>/<MODEL_NAME>.
MODEL_LOCATION = 'models/cliptextclassifier/'
MODEL_VERSION = '1'
MODEL_NAME = 'cliptextclassifier.pth'
FULL_LOCATION = os.path.join(MODEL_LOCATION, MODEL_VERSION)
MODEL_PATH = os.path.join(FULL_LOCATION, MODEL_NAME)
os.makedirs(FULL_LOCATION, exist_ok=True)  # create the checkpoint directory up front

# define loss function and model
# NOTE(review): BCEWithLogitsLoss scores each of the 6 outputs as an independent
# binary problem, while accuracy is computed via arg-max over one-hot labels;
# confirm this is preferred over a single-label loss such as CrossEntropyLoss.
criterion = nn.BCEWithLogitsLoss()
model = CLIPClassifier(device=device)
# NOTE(review): model.parameters() presumably includes the CLIP encoder's
# weights as well as the head's, so the encoder would be fine-tuned too -
# confirm that this is intended.
optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate, momentum=momentum)

# Per-epoch history; it is stored inside every checkpoint under the 'loss' key.
trainval_lossacc = {'train_loss':[], 'train_acc':[],'valid_loss':[],'valid_acc':[]}

# Early-stopping state: best validation loss so far and epochs since it improved.
min_val_loss = math.inf
epoch_no_improv = 0

# Epoch numbering starts at 1; a restored checkpoint may overwrite this.
cur_epoch = 1

  if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):


In [13]:
# # load from checkpoint if it exists
# checkpoint = torch.load(MODEL_PATH, map_location=torch.device('cpu'))

# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# cur_epoch = checkpoint['epoch']
# trainval_lossacc = checkpoint['loss']

In [14]:
# Train for up to n_epochs epochs with early stopping.
#
# Each epoch runs one optimisation pass over trainloader and one gradient-free
# pass over testloader. Whenever the held-out loss reaches a new minimum, a
# checkpoint (epoch, model/optimizer state, metric history) is written to
# MODEL_PATH; after `patience` consecutive epochs without a new minimum,
# training stops.
#
# NOTE(review): testloader doubles as the validation set for early stopping and
# checkpoint selection, so the reported "Val." metrics are not an untouched
# test estimate.
# Epoch and batch metrics are plain means of the per-batch values.
for epoch in range(cur_epoch, n_epochs + 1):  # epochs are 1-based, so the upper bound is inclusive

    # Training
    epoch_loss = 0.0
    epoch_acc = 0.0
    running_loss = 0.0
    model.train()
    for i, (_, texts, labels) in enumerate(trainloader):
        # truncate: some titles are longer than CLIP's 77-token context; the
        # first 77 tokens are assumed to carry enough context.
        text_tokens = clip.tokenize(texts, truncate=True).to(device)
        labels = labels.float().to(device)

        # zero parameter gradients
        optimizer.zero_grad()

        # Forward / backward / update
        output = model(text_tokens)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        # Accumulate statistics
        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        epoch_acc += binary_accuracy(output, labels)
        if i % 200 == 199:  # print every 200 mini-batches
            print(f'[Epoch {epoch:d}, Step {i + 1:5d}] loss: {running_loss / 200:.3f}')
            running_loss = 0.0
    train_loss, train_acc = epoch_loss / len(trainloader), epoch_acc / len(trainloader)

    trainval_lossacc['train_loss'].append(train_loss)
    trainval_lossacc['train_acc'].append(train_acc)

    # Evaluate with Test dataset
    epoch_loss = 0.0
    epoch_acc = 0.0
    model.eval()
    with torch.no_grad():
        for _, texts, labels in testloader:
            text_tokens = clip.tokenize(texts, truncate=True).to(device)
            labels = labels.float().to(device)

            # Forward
            output = model(text_tokens)

            # Compute the loss using the final output
            loss = criterion(output, labels)

            epoch_loss += loss.item()
            epoch_acc += binary_accuracy(output, labels)

    valid_loss, valid_acc = epoch_loss / len(testloader), epoch_acc / len(testloader)

    trainval_lossacc['valid_loss'].append(valid_loss)
    trainval_lossacc['valid_acc'].append(valid_acc)
    # Showing statistics
    print(f'[{epoch}] \tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'[{epoch}] \t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

    # Early stopping condition
    # https://stackoverflow.com/questions/60200088/how-to-make-early-stopping-in-image-classification-pytorch
    if valid_loss < min_val_loss:
        # New best: reset the patience counter and checkpoint this epoch.
        min_val_loss = valid_loss
        epoch_no_improv = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': trainval_lossacc,
            }, MODEL_PATH)

        print(f'Min val loss {min_val_loss:.3f}')
    else:
        epoch_no_improv += 1
        if epoch_no_improv >= patience:
            print('Early Stopping')
            break
        print(f"no improvement = {epoch_no_improv}")
    print()

print('Finished Training')

 does not have profile information (Triggered internally at ../torch/csrc/jit/codegen/cuda/graph_fuser.cpp:105.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


[Epoch 1, Step   200] loss: 0.671
[Epoch 1, Step   400] loss: 0.636
[Epoch 1, Step   600] loss: 0.572
[1] 	Train Loss: 0.602 | Train Acc: 31.28%
[1] 	 Val. Loss: 0.475 |  Val. Acc: 39.52%
Min val loss 0.475

[Epoch 2, Step   200] loss: 0.447
[Epoch 2, Step   400] loss: 0.410
[Epoch 2, Step   600] loss: 0.392
[2] 	Train Loss: 0.410 | Train Acc: 39.62%
[2] 	 Val. Loss: 0.381 |  Val. Acc: 39.51%
Min val loss 0.381

[Epoch 3, Step   200] loss: 0.383
[Epoch 3, Step   400] loss: 0.383
[Epoch 3, Step   600] loss: 0.380
[3] 	Train Loss: 0.382 | Train Acc: 39.52%
[3] 	 Val. Loss: 0.375 |  Val. Acc: 39.52%
Min val loss 0.375

[Epoch 4, Step   200] loss: 0.379
[Epoch 4, Step   400] loss: 0.380
[Epoch 4, Step   600] loss: 0.381
[4] 	Train Loss: 0.379 | Train Acc: 39.62%
[4] 	 Val. Loss: 0.374 |  Val. Acc: 39.51%
Min val loss 0.374

[Epoch 5, Step   200] loss: 0.381
[Epoch 5, Step   400] loss: 0.379
[Epoch 5, Step   600] loss: 0.377
[5] 	Train Loss: 0.378 | Train Acc: 39.61%
[5] 	 Val. Loss: 0.374 

# Final Evaluation

In [20]:
# Restore the checkpoint written by the training loop. No existence check is
# performed, so MODEL_PATH must already be present.
# Tensors are mapped to the CPU on load.
checkpoint = torch.load(MODEL_PATH, map_location=torch.device('cpu'))

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
cur_epoch = checkpoint['epoch']  # epoch at which this checkpoint was saved
trainval_lossacc = checkpoint['loss']  # metric history up to and including that epoch

In [21]:
# Report the last entry of the restored history. The history is saved together
# with the checkpoint, so these are the metrics of the checkpointed epoch, not
# necessarily of the last epoch that was run.
print(f"final train loss: {trainval_lossacc['train_loss'][-1]:.5f}")
print(f"final train acc: {trainval_lossacc['train_acc'][-1]:.5f}")
print()
print(f"final valid loss: {trainval_lossacc['valid_loss'][-1]:.5f}")
print(f"final valid acc: {trainval_lossacc['valid_acc'][-1]:.5f}")

final train loss: 0.24503
final train acc: 0.66633

final valid loss: 0.27108
final valid acc: 0.63299


In [25]:
# Same metrics as a one-line-per-split summary, prefixed with the checkpoint's
# epoch number and with accuracies shown as percentages.
print(f"[{cur_epoch}] \tTrain Loss: {trainval_lossacc['train_loss'][-1]:.5f} | Train Acc: {trainval_lossacc['train_acc'][-1]*100:.5f}%")
print(f"[{cur_epoch}] \t Val. Loss: {trainval_lossacc['valid_loss'][-1]:.5f} |  Val. Acc: {trainval_lossacc['valid_acc'][-1]*100:.5f}%")
# Output recorded from a previous run, kept for reference:
# [30] 	Train Loss: 0.24503 | Train Acc: 66.63274%
# [30] 	 Val. Loss: 0.27108 |  Val. Acc: 63.29861%

[30] 	Train Loss: 0.24503 | Train Acc: 66.63274%
[30] 	 Val. Loss: 0.27108 |  Val. Acc: 63.29861%

