In [None]:
import torch
import clip
from PIL import Image
import os
import shutil
import sys
import time
import warnings
from random import sample

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn import metrics
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR

import cgcnn
from cgcnn.data import CIFData
from cgcnn.data import collate_pool, get_train_val_test_loader
from cgcnn.model import CrystalGraphConvNet

# Make float32 the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float32)

In [None]:
# Show the model names this CLIP installation reports; one of them is passed
# to clip.load() in the next cell.
clip.available_models()

In [None]:
# Load the "RN50" CLIP checkpoint together with its image preprocessing
# transform onto the chosen device.
device = "mps"
model, preprocess = clip.load("RN50", device=device)

# Reference image, preprocessed and given a leading batch dimension.
image = (
    preprocess(Image.open("CLIP.png"))
    .unsqueeze(0)
    .to(device)
)

# Three probe prompts: two single words and one two-sentence abstract.
text = clip.tokenize(
    [
        "octopussy",
        "cat",
        "A new design strategy for high-performance organic cathode active materials for lithium-ion batteries is presented.X-ray diffraction measurements and sorption experiments demonstrated that the intercolumnar spaces in PCT-1 can incorporate various molecules accompanied by lattice expansion.",
    ]
).to(device)

# Only the text branch is run in this cell; `text_features` is used later to
# size the CIF encoder output. The image branch (model.encode_image(image) /
# model(image, text)) is not executed here.
with torch.no_grad():
    text_features = model.encode_text(text)

In [None]:
# Optional, destructive maintenance step. According to the note in the dataset
# cell, find_errors() finds and deletes invalid CIF files, so it should only be
# run on demand. It is therefore opt-in, and it builds its own CIFData instance
# so this cell no longer depends on `data`, which is first assigned in a later
# cell (the bare call raised NameError on a fresh kernel).
RUN_CIF_CLEANUP = False  # flip to True to scan ./cifs/ once
if RUN_CIF_CLEANUP:
    CIFData("./cifs/").find_errors()

In [None]:
# Dataset of crystal structures read from the local ./cifs/ directory.
# run only if you have invalid CIF file errors - this will find & delete them
# data.find_errors()
data = CIFData("./cifs/")

# Peek at the first item to size the network: the trailing dimensions of the
# first two entries of its leading element are passed as the first two
# positional constructor arguments.
# NOTE(review): presumably these are the atom- and neighbour-feature widths,
# as in upstream CGCNN - confirm against the local cgcnn.model.
s, _, _ = data[0]
first_input_width = s[0].shape[-1]
second_input_width = s[1].shape[-1]

# The output width matches the CLIP text embedding so both encoders produce
# vectors of the same size.
encoder_options = dict(
    n_conv=3,
    n_h=2,
    output_dim=text_features.shape[-1],
    classification=False,
)
cif_encoder = CrystalGraphConvNet(first_input_width, second_input_width, **encoder_options)

In [None]:
text = [
  """An octopus (pl: octopuses or octopodes, see below for variants) is a soft-bodied, eight-limbed mollusc of the order Octopoda (/ɒkˈtɒpədə/, ok-TOP-ə-də[3]). The order consists of some 300 species and is grouped within the class Cephalopoda with squids, cuttlefish, and nautiloids. Like other cephalopods, an octopus is bilaterally symmetric with two eyes and a beaked mouth at the center point of the eight limbs.[a] The soft body can radically alter its shape, enabling octopuses to squeeze through small gaps. They trail their eight appendages behind them as they swim. The siphon is used both for respiration and for locomotion, by expelling a jet of water. Octopuses have a complex nervous system and excellent sight, and are among the most intelligent and behaviourally diverse of all invertebrates. Octopuses inhabit various regions of the ocean, including coral reefs, pelagic waters, and the seabed; some live in the intertidal zone and others at abyssal depths. Most species grow quickly, mature early, and are short-lived. In most species, the male uses a specially adapted arm to deliver a bundle of sperm directly into the female's mantle cavity, after which he becomes senescent and dies, while the female deposits fertilised eggs in a den and cares for them until they hatch, after which she also dies. Strategies to defend themselves against predators include the expulsion of ink, the use of camouflage and threat displays, the ability to jet quickly through the water and hide, and even deceit. All octopuses are venomous, but only the blue-ringed octopuses are known to be deadly to humans.""",
  'The title complex was synthesized in 41.6% yield by reactions between Os3(CO)11(CH3CN) and 2,4,6-tri­methyl­hexa­hydro-1,3,5-triazine.Each Os atom exhibits a pseudo-octa­hedral coordination environment, discounting the bridging Os—Os bond.',
  'The molecular salt, C23H26N2O2+Cl, was obtained from 1-isobutyl-8,9-dimeth­oxy-3-phenyl-5,6-di­hydro­imidazo[5,1-a]iso­quinoline.In the crystal structure, centrosymmetric dimers are formed by N—HCl and C—HCl hydrogen bonds.',
  'The title compound, C16H20N4, was synthesized by cyanation of brom­hexine.The substituted aniline and cyclo­hexane rings are inclined to one another by 37.26 (6)in one mol­ecule and by 22.84 (7)in the other.',
  'Your purchase has been completed.Your documents are now available to view.Your purchase has been completed.Your documents are now available to view.',
  'Monomeric boroles have been gaining attention as reagents for the synthesis of heterocycles due to their ability to insert atoms into the BC4 ring in a single step.This work demonstrates that insertion chemistry is possible with Diels–Alder dimeric boroles.',
  'Deep-blue thermally activated delayed fluorescence (TADF) emitters are promising alternatives for conventional fluorescence and phosphorescence materials.Four new donor–acceptor (D–A)-type TADF molecules incorporating phenazasiline, phenazagermine, and tetramethylcarbazole as weak D units were designed and synthesized.Photophysical investigation revealed that phenazasiline and phenazagermine-based emitters concurrently exhibit blue TADF emissions.',
  'Silyl, silylene and silene complexes were accessed via reactions of [(dmpe)2MnH(C2H4)] (1) with hydrosilanes, in some cases followed by ethylene.'
]
# Pad every passage to one common length that is a whole number of 77-token
# windows (77 is the context length CLIP's tokenizer uses by default). The
# character count of the longest passage serves as the upper bound.
context_length = max(len(passage) for passage in text)
context_length = int(np.ceil(context_length / 77) * 77)

# Shape after reshape: (n_passages, n_windows, 77).
tokens = clip.tokenize(text, context_length=context_length).reshape(len(text), -1, 77).to(device)

embeddings = []
with torch.no_grad():
    # The image embedding is needed for the similarity scores below. It was
    # never computed anywhere in the notebook (the earlier call is commented
    # out), which made this cell fail with NameError on a fresh kernel.
    image_features = model.encode_image(image)

    # One embedding per passage: encode each 77-token window, then average.
    # The loop variable is deliberately not called `sample`, so the
    # `random.sample` import at the top of the notebook is not overwritten.
    # NOTE(review): every passage is padded to the longest one, so short
    # passages contribute many padding-only windows to this mean - confirm
    # that is intended.
    for windows in tokens:
        ctx = model.encode_text(windows)
        ctx = torch.mean(ctx, dim=0)
        embeddings.append(ctx)
    embeddings = torch.stack(embeddings)

    # L2-normalise both sides so the dot product is a cosine similarity.
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    embeddings = embeddings / embeddings.norm(dim=1, keepdim=True)

    # cosine similarity as logits, scaled by CLIP's learned logit scale
    logit_scale = model.logit_scale.exp()
    logits_per_image = logit_scale * image_features @ embeddings.t()
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    # One probability per passage for the single image.
    for passage_prob in probs[0]:
        print(passage_prob)

In [None]:
# Scratch check of in-place list growth; the cell displays the final list.
x = [1, 2, 3, 4]
x += [5]
x

In [None]:
# test, val, train ratio is 0.1, 0.1, 0.8
train_loader, val_loader, test_loader = get_train_val_test_loader(
        train_ratio=0.8,
        val_ratio=0.1,
        test_ratio=0.1,
        dataset=data,
        collate_fn=collate_pool,
        batch_size=1,
        return_test=True)

# Optimise BOTH networks. Previously only the CLIP parameters were handed to
# Adam, so the CIF encoder received gradients but was never updated even though
# the training loop puts it in train mode and checkpoints it.
trainable_params = list(model.parameters()) + list(cif_encoder.parameters())
optimizer = optim.Adam(trainable_params, lr=0.001)

# Multiply the learning rate by 0.1 at scheduler steps 100 and 150.
scheduler = MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

In [None]:
# Re-execute the top-level `cgcnn` package so edits to it are picked up
# without restarting the kernel.
# NOTE(review): importlib.reload() reloads only the module object it is given.
# Names already bound with `from cgcnn.data import ...` / `from cgcnn.model
# import ...` in the imports cell keep pointing at the previously loaded
# objects - confirm whether those submodules need reloading and re-importing.
import importlib
importlib.reload(cgcnn)

In [None]:
# Number of batches in the train / validation / test loaders, in that order.
tuple(len(loader) for loader in (train_loader, val_loader, test_loader))

In [None]:
import os

def notify(title, text):
    """Show a desktop notification through macOS ``osascript`` (best effort).

    Args:
        title: Notification title; converted with ``str()``.
        text: Notification body; converted with ``str()``.

    Returns:
        None. A missing ``osascript`` binary (e.g. on non-macOS hosts) or a
        failing script is ignored, so calling this never interrupts training.
    """
    # Local import keeps the function self-contained within this cell.
    import subprocess

    def _applescript_literal(value):
        # AppleScript string literals need backslashes and double quotes escaped.
        return '"' + str(value).replace("\\", "\\\\").replace('"', '\\"') + '"'

    script = "display notification {} with title {}".format(
        _applescript_literal(text), _applescript_literal(title))
    try:
        # Pass an argv list instead of a shell string: quotes, tabs or shell
        # metacharacters in the message can no longer break or inject commands.
        subprocess.run(["osascript", "-e", script], check=False)
    except OSError:
        # osascript is unavailable; notifications are optional.
        pass

In [None]:
# Smoke test: push one training batch through the CIF encoder, embed a dummy
# caption with the text encoder, and print the resulting loss.
# NOTE(review): `encode_text` and `loss_func` are defined in later cells, so
# this cell raises NameError on a fresh kernel unless those cells run first.
# Graph inputs are moved to the CPU device first.
device = "cpu"
cif_encoder = cif_encoder.float()
for batch_idx, (inputs, targets, _) in enumerate(train_loader):
    inputs = (Variable(inputs[0].to(device, non_blocking=True)),
                Variable(inputs[1].to(device, non_blocking=True)),
                inputs[2].to(device, non_blocking=True),
                [crys_idx.to(device, non_blocking=True) for crys_idx in inputs[3]])
    # The global `device` is switched here because encode_text() reads it when
    # placing the token tensor; the statement order in this loop matters.
    device = "mps"
    text_embeddings = encode_text(["cif ashkjdahsd kajshd kasjhd coo"])
    print(targets)
    cif_embeddings = cif_encoder(*inputs)
    # L2-normalise each row of both embedding matrices.
    cif_embeddings = cif_embeddings / cif_embeddings.norm(dim=1, keepdim=True)
    text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
    # The CIF embeddings are moved to `device` (now "mps") before the loss.
    print(loss_func(cif_embeddings.float().to(device), text_embeddings.float()))
    # Only the first batch is inspected.
    break

In [None]:
# Quick probe of encode_text() on a single short string.
# NOTE(review): encode_text is defined in a later cell, so on a fresh kernel
# this cell raises NameError unless that cell has been run first.
encode_text(["cicosjao"])

In [None]:
# Scratch check: row-wise softmax over the Gram matrix of random embeddings.
batch_size = 4
dim = 256
embeddings = torch.randn(2, 5)
out = embeddings @ embeddings.T
# torch.softmax is called directly because the `F` alias is only imported in a
# later cell; using `F` here raised NameError on a fresh kernel.
print(torch.softmax(out, dim=-1))

In [None]:
# Sanity check: an L2-normalised row dotted with itself gives 1.
# torch.nn.functional is referenced by its full path because the `F` alias is
# only imported in a later cell (using `F` here raised NameError on a fresh
# kernel).
unit_row = torch.nn.functional.normalize(torch.tensor([[1., 2., 3., 4.]]))
unit_row @ unit_row.T

In [None]:
# Device string for the cells below; encode_text() reads this global when it
# places the token tensor.
device = "mps"
import torch.nn.functional as F
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between raw scores and soft (probability-like) targets.

    Args:
        preds: Tensor of unnormalised scores with at least two dimensions;
            log-softmax is taken over the last dimension.
        targets: Tensor of target weights, multiplied element-wise with the
            log-probabilities.
        reduction: 'none' returns the loss summed over dim 1 (one value per
            row for 2-D input), 'mean' returns its mean, 'sum' its sum.

    Returns:
        The per-row loss tensor, or a scalar tensor for 'mean' / 'sum'.

    Raises:
        ValueError: if `reduction` is not one of the supported values
            (previously such a value silently returned None).
    """
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            "reduction must be 'none', 'mean' or 'sum', got {!r}".format(reduction))
    loss = (-targets * torch.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss
def loss_func(feat1, feat2):
    """Symmetric soft-target contrastive loss between two embedding batches.

    The cross-similarity matrix ``feat1 @ feat2.T`` is scored against soft
    targets obtained by a row-wise softmax over the average of the two
    self-similarity matrices. The loss is evaluated along rows and along
    columns, the two are averaged, and the mean over the batch is returned.
    The soft-target matrix is printed on every call.

    Args:
        feat1: 2-D tensor, one embedding per row.
        feat2: 2-D tensor with the same shape as `feat1`.

    Returns:
        Scalar tensor holding the mean loss.
    """
    cross_logits = feat1 @ feat2.T
    mean_self_similarity = (feat1 @ feat1.T + feat2 @ feat2.T) / 2
    targets = F.softmax(mean_self_similarity, dim=-1)
    print(targets)
    row_loss = cross_entropy(cross_logits, targets, reduction='none')
    col_loss = cross_entropy(cross_logits.T, targets.T, reduction='none')
    return ((row_loss + col_loss) / 2.0).mean()
# Try the loss on two identical toy batches: one all-zero row and one [1, 1]
# row, each passed through F.normalize first.
toy_feat1 = F.normalize(torch.tensor([[0., 0.], [1., 1.]]))
toy_feat2 = F.normalize(torch.tensor([[0., 0.], [1., 1.]]))
loss_func(toy_feat1, toy_feat2)

In [None]:
# Device string read as a global by encode_text() and by the training loop
# defined further down in this cell.
device = "mps"
import torch.nn.functional as F
# Desktop notification sent as soon as this cell starts executing; the actual
# train(...) call is at the end of the cell.
notify('Training', 'Training started')
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between raw scores and soft (probability-like) targets.

    Args:
        preds: Tensor of unnormalised scores with at least two dimensions;
            log-softmax is taken over the last dimension.
        targets: Tensor of target weights, multiplied element-wise with the
            log-probabilities.
        reduction: 'none' returns the loss summed over dim 1 (one value per
            row for 2-D input), 'mean' returns its mean, 'sum' its sum.

    Returns:
        The per-row loss tensor, or a scalar tensor for 'mean' / 'sum'.

    Raises:
        ValueError: if `reduction` is not one of the supported values
            (previously such a value silently returned None).
    """
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            "reduction must be 'none', 'mean' or 'sum', got {!r}".format(reduction))
    loss = (-targets * torch.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss
def loss_func(feat1, feat2):
    """Symmetric soft-target contrastive loss between two embedding batches.

    The cross-similarity matrix ``feat1 @ feat2.T`` is scored against soft
    targets obtained by a row-wise softmax over the average of the two
    self-similarity matrices. The loss is evaluated along rows and along
    columns, the two are averaged, and the mean over the batch is returned.

    Args:
        feat1: 2-D tensor, one embedding per row (callers in this notebook
            pass L2-normalised rows).
        feat2: 2-D tensor with the same shape as `feat1`.

    Returns:
        Scalar tensor holding the mean loss.
    """
    logits = feat1 @ feat2.T
    feat1_similarity = feat1 @ feat1.T
    feat2_similarity = feat2 @ feat2.T
    # Use the self-similarities computed above. The previous code referenced
    # the undefined names `images_similarity` / `texts_similarity`, so every
    # call raised NameError.
    targets = F.softmax(
            (feat1_similarity + feat2_similarity) / 2, dim=-1
        )
    feat1_loss = cross_entropy(logits, targets, reduction='none')
    feat2_loss = cross_entropy(logits.T, targets.T, reduction='none')
    loss = (feat1_loss + feat2_loss) / 2.0
    return loss.mean()
def encode_text(targets):
    """Embed each string in `targets` with the CLIP text encoder.

    Every string is tokenized to a common length that is a whole number of
    77-token windows (the longest string's character count is used as the
    bound), each window is encoded separately with ``model.encode_text``, and
    the window embeddings are averaged into one vector per string.

    Relies on the notebook globals `model` and `device`.

    Args:
        targets: Non-empty sequence of strings.

    Returns:
        Tensor with one row per input string.

    Raises:
        ValueError: if `targets` is empty (raised by ``max``).
    """
    longest = max(len(c) for c in targets)
    # Always keep at least one window: with only empty strings the previous
    # computation produced context_length == 0 and tokenization failed.
    n_windows = max(1, int(np.ceil(longest / 77)))
    context_length = n_windows * 77

    tokens = clip.tokenize(targets, context_length=context_length).reshape(len(targets), -1, 77).to(device)
    # NOTE(review): all strings are padded to the longest one, so short strings
    # contribute padding-only windows to the mean below - confirm intended.
    window_means = [model.encode_text(windows).mean(dim=0) for windows in tokens]
    return torch.stack(window_means)

# `model` is the CLIP network, used here as the text encoder (its image
# encoder is unused); `cif_encoder` is the crystal-graph network.
# Cast both to float32 and switch both to training mode.
model = model.float().train()
cif_encoder = cif_encoder.float().train()
def _contrastive_batch_loss(inputs, targets):
    """Return loss_func(cif, text) for one (inputs, targets) batch of a loader."""
    graph_inputs = (inputs[0].float(),
                    inputs[1].float(),
                    inputs[2],
                    list(inputs[3]))
    cif_embeddings = cif_encoder(*graph_inputs)
    text_embeddings = encode_text(targets)
    # L2-normalise each row of both embedding matrices.
    cif_embeddings = cif_embeddings / cif_embeddings.norm(dim=1, keepdim=True)
    text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
    return loss_func(cif_embeddings.float().to(device), text_embeddings.float())


def _mean_validation_loss():
    """Return the mean loss over `val_loader` (NaN if the loader is empty)."""
    was_training = (model.training, cif_encoder.training)
    # Evaluate in eval mode so normalisation layers use their running stats.
    model.eval()
    cif_encoder.eval()
    total, n_batches = 0.0, 0
    try:
        with torch.no_grad():
            for val_inputs, val_targets, _ in val_loader:
                total += _contrastive_batch_loss(val_inputs, val_targets).item()
                n_batches += 1
    finally:
        # Restore whatever mode the networks were in before validation.
        model.train(was_training[0])
        cif_encoder.train(was_training[1])
    return total / n_batches if n_batches else float("nan")


def train(epochs, val_interval=1):
    """Jointly train the CLIP text encoder and the CIF encoder.

    Uses the notebook globals `model`, `cif_encoder`, `optimizer`,
    `scheduler`, `train_loader`, `val_loader` and `device`.

    Every `val_interval` training batches the validation loss is computed,
    both networks are checkpointed under ``checkpoints/`` (one file pair per
    epoch, overwritten within the epoch) and a progress line is printed and
    sent as a notification.

    Changes relative to the previous version:
      * the validation loop uses its own loop variables, so the progress line
        reports the training batch rather than the last validation batch;
      * the validation loss is averaged over all validation batches instead of
        keeping only the last one, and an empty loader no longer crashes;
      * validation runs in eval mode;
      * scheduler.step() is called after the epoch's optimizer steps;
      * the checkpoint directory is created if missing.

    Args:
        epochs: Number of passes over `train_loader`.
        val_interval: Validate / checkpoint / log every this many training
            batches. The default of 1 keeps the previous every-batch cadence.
    """
    os.makedirs("checkpoints", exist_ok=True)
    for epoch in range(epochs):
        for batch_idx, (inputs, targets, _) in enumerate(train_loader):
            optimizer.zero_grad()
            loss = _contrastive_batch_loss(inputs, targets)
            loss.backward()
            optimizer.step()

            if batch_idx % val_interval != 0:
                continue

            val_loss = _mean_validation_loss()

            # save checkpoints with loss & epoch metrics
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
                'val_loss': val_loss
                }, 'checkpoints/checkpoint_{}.pt'.format(epoch))
            torch.save({
                'epoch': epoch,
                'model_state_dict': cif_encoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
                'val_loss': val_loss
                }, 'checkpoints/checkpoint_cif_{}.pt'.format(epoch))

            txt = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tVal Loss: {:.6f}'.format(
                epoch, batch_idx*len(targets), len(train_loader),
                100. * batch_idx / len(train_loader), loss.item(), val_loss)
            print(txt)
            notify(f"Training Epoch {epoch}", txt)
        # Advance the LR schedule once per epoch, after the optimizer updates.
        scheduler.step()
# Launch training for ten epochs.
train(epochs=10)

In [None]:
# save models to clamp_weights folder
# Create the folder first: torch.save does not create missing directories, so
# the first run on a clean checkout would otherwise fail here after training.
os.makedirs("clamp_weights", exist_ok=True)
torch.save(model.state_dict(), "clamp_weights/text_encoder.pt")
torch.save(cif_encoder.state_dict(), "clamp_weights/cif_encoder.pt")

In [None]:
# Display the CIF encoder (its repr) as the cell output.
cif_encoder

In [None]:
# Restore the CIF encoder weights from the saved checkpoint file, mapping all
# stored tensors to the CPU while loading.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
cif_encoder.load_state_dict(
    torch.load(
        "./checkpoints/checkpoint_cif_val_least.pt",
        map_location=torch.device('cpu'),
    )["model_state_dict"]
)

In [None]:
# Restore the CLIP weights from the saved checkpoint file, mapping all stored
# tensors to the CPU while loading.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(
    torch.load(
        "./checkpoints/checkpoint_clip_val_least.pt",
        map_location=torch.device('cpu'),
    )["model_state_dict"]
)