In [1]:
import torch
import torch.nn as nn
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'

In [2]:
from transformers import CLIPModel, AutoModel, CLIPProcessor
from preprocessing import fcgr, protein_to_dna
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
#from CLIP.clip import clip
#from CLIP.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from clip import clip
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

In [4]:
class image_title_dataset(Dataset):
    def __init__(self, processor, images, labels, max_len):
        # Initialize image paths and corresponding texts
        self.images = images
        self.max_len = max_len
        # Tokenize text using CLIP's tokenizer
        self.labels = labels
        self.processor = processor
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Preprocess image using CLIP's preprocessing function
        labels_short = self.labels[idx:idx + 1].tolist()[:][:self.max_len]

        tokens = clip.tokenize(self.labels[idx:idx + 1].tolist()[:50], context_length = 256,truncate=True)[:1, :self.max_len]

        if tokens.size()[1] < self.max_len:
            tokens = torch.cat([tokens,
                       torch.zeros(size=(1, self.max_len - tokens.size()[1]))], dim=1).type(
                torch.int)
        #attmask = torch.zeros()
        #pixel_values = torch.transforms

        inputs = self.processor(text=labels_short, images=self.images[idx:idx+1], return_tensors="pt", padding=True)
        inputs['input_ids'] = tokens
        if inputs['attention_mask'].size()[1] > self.max_len:
            inputs['attention_mask'] = inputs['attention_mask'][:1, :self.max_len]
        inputs['attention_mask'] = torch.cat([inputs['attention_mask'], torch.zeros(size=(1, len(inputs['input_ids'][0]) - len(inputs['attention_mask'][0])))], dim=1).type(torch.int)

        #inputs['input_ids'] = torch.cat([inputs['input_ids'], torch.Tensor([[0] * (self.max_len - len(inputs['input_ids'][0]))])], dim=1)
        #inputs['attention_mask'] = torch.cat([inputs['attention_mask'], torch.Tensor([[0] * (self.max_len - len(inputs['attention_mask'][0]))])], dim=1)
        return inputs

In [5]:
def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        if p.requires_grad:
            p.grad.data = p.grad.data.float()

In [6]:
def evaluate_training(model,loader_test, criterion):
    loss_all = []
    with torch.no_grad():
        
        for i, inputs in enumerate(loader_test):
            outputs = model(input_ids=torch.Tensor(inputs['input_ids']).type(torch.int),attention_mask=torch.Tensor(inputs['attention_mask']).squeeze(1), pixel_values=torch.Tensor(inputs['pixel_values']).squeeze(1))
            
            logits_i = outputs.logits_per_image
            logits_t = outputs.logits_per_text
            #probs = logits.softmax(dim=1)

            labels = torch.arange(0, logits_i.shape[0])
            loss_i = criterion(logits_i, labels)
            loss_t = criterion(logits_t, labels)

            loss = (loss_i + loss_t)/2
            loss_all.append(loss.item())

    return np.mean(loss_all)
        
            

In [7]:
def train(model, processor, images, labels, epochs=1):
    #tokenizer = get_or_build_tokenizer(labels)
    criterion = nn.CrossEntropyLoss()
    #for img in images:

    #labeltokens = clip.tokenize(labels.tolist(), 128)

    #labeltokens = np.array([np.array(tokenizer.encode(label).ids) for label in labels])

    #inputs = processor(text=labeltokens, images=images, return_tensors='pt', padding=True)
    #inputs = processor(text=torch.Tensor(labeltokens[:1]), images=torch.permute(images[:1], dims=(0,2,3,1)) )
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = Adam(params, lr=1e-5, weight_decay=0.0001)
    
    n=len(labels)
    assert n == images.shape[0]
    
    images_train = images[:int(n*0.9)]
    images_test = images[int(n*0.9):]
    labels_train = labels[:int(n*0.9)]
    labels_test = labels[int(n*0.9):]

    
    dataset_train = image_title_dataset(processor, images_train, labels_train, 77) #instead of 77
    loader_train = DataLoader(dataset_train, batch_size=128, shuffle=True)

    dataset_test = image_title_dataset(processor, images_test, labels_test, 77) #instead of 77
    loader_test = DataLoader(dataset_test, batch_size=128, shuffle=True)


    for epoch in range(epochs):
        loss_epoch = []
        print(f"Epoch {epoch+1}/{epochs}--------------------------------")
        for i, inputs in enumerate(loader_train):
            
        #inputs = processor(text=labels[:].tolist(), images=images[:], return_tensors="pt", padding=True)

            optimizer.zero_grad()
            outputs = model(input_ids=torch.Tensor(inputs['input_ids']).type(torch.int),
                attention_mask=torch.Tensor(inputs['attention_mask']).squeeze(1), pixel_values=torch.Tensor(inputs['pixel_values']).squeeze(1))

            logits_i = outputs.logits_per_image
            logits_t = outputs.logits_per_text
            #probs = logits.softmax(dim=1)

            labels = torch.arange(0, logits_i.shape[0])
            loss_i = criterion(logits_i, labels)
            loss_t = criterion(logits_t, labels)

            loss = (loss_i + loss_t)/2
            loss.backward()

            if device == 'cpu':
                optimizer.step()
            else:
                convert_models_to_fp32(model)
                optimizer.step()
                clip.model.convert_weights(model)

            print(f"{loss.item()}... batch {i+1}/{len(loader_train)}")
            loss_epoch.append(loss.item())
    
            if i%10==0:
                loss_epoch_test = evaluate_training(model,loader_test,criterion)
                print(f"Current eval loss after batch {i+1}/{len(loader_train)} epoch={epoch} = {loss_epoch_test}")

        loss_epoch = np.mean(loss_epoch)
        print(f"Loss epoch {epoch} = {loss_epoch}")
        
        loss_epoch_test = evaluate_training(model,loader_test,criterion)
        print(f"Final eval loss epoch={epoch} = {loss_epoch_test}")



In [8]:
#model, processor = clip.load("ViT-B/32",device=device,jit=True) #Must set jit=False for training
"""
    model : torch.nn.Module
        The CLIP model


    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""

'\n    model : torch.nn.Module\n        The CLIP model\n\n\n    preprocess : Callable[[PIL.Image], torch.Tensor]\n        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input\n'

In [9]:
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [10]:
nsamples = 10000
labelcol = "Protein names"
inputcol = "Sequence"
labels = pd.read_csv(r"labels.csv")[labelcol][:nsamples]
inputs = pd.read_csv(r"sequences.csv")[inputcol][:nsamples]


In [11]:
images = np.array([fcgr(seq, k=7) for seq in inputs])

In [12]:
images = np.array([img/np.sum(img) for img in images])
images = torch.Tensor(images).unsqueeze(1).repeat(1, 3, 1, 1)

In [13]:
train(model, processor, images, labels, epochs=1)

Token indices sequence length is longer than the specified maximum sequence length for this model (79 > 77). Running this sequence through the model will result in indexing errors
It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.


Epoch 1/1--------------------------------
5.394437789916992... batch 1/71
Current eval loss after batch 1/71 epoch=0 = 5.459358990192413
5.814339637756348... batch 2/71
5.080859661102295... batch 3/71
5.6517744064331055... batch 4/71
4.987646102905273... batch 5/71
4.921845436096191... batch 6/71
4.908542633056641... batch 7/71
4.8676652908325195... batch 8/71
4.863919734954834... batch 9/71
4.863617897033691... batch 10/71
4.8658037185668945... batch 11/71
Current eval loss after batch 11/71 epoch=0 = 4.837612330913544
4.871415138244629... batch 12/71
4.861046314239502... batch 13/71
4.862253189086914... batch 14/71
4.856645584106445... batch 15/71
4.857817649841309... batch 16/71
4.856191635131836... batch 17/71
4.855138301849365... batch 18/71
4.856529235839844... batch 19/71
4.855803489685059... batch 20/71
4.855349540710449... batch 21/71
Current eval loss after batch 21/71 epoch=0 = 4.828933000564575
4.854987144470215... batch 22/71
4.855568885803223... batch 23/71


In [None]:
import os
model_path = r"C:\Users\aapolina\CODE\sdsc_hackatchon_genAI\CLIP-DNA\from_azure"

from datetime import datetime
timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")


model_name = f"model_clop_({timestamp}).ckpt"
torch.save(model.state_dict(), os.path.join(model_path, model_name) )

In [None]:
x =torch.load(os.path.join(model_path, r"model_clop_(2023_12_01-11_43_07).ckpt"))
x.keys()
