In [None]:
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

In [None]:
!curl https://doggo.ninja/6A0kZE.zip -o cgcnn.zip
!mkdir cgcnn
!unzip -o cgcnn.zip -d cgcnn

In [None]:
!pip install pymatgen

In [None]:
import clip
from PIL import Image

import torch
import os
import shutil
import sys
import time
import warnings
from random import sample

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn import metrics
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
import clip
import cgcnn
from cgcnn.data import CIFData
from cgcnn.data import collate_pool, get_train_val_test_loader
from cgcnn.model import CrystalGraphConvNet

# Make float32 the default floating-point dtype for newly created tensors. This is
# also PyTorch's built-in default, so the call mainly guards against an earlier change.
torch.set_default_dtype(torch.float32)

In [None]:
# List the checkpoint names the `clip` package can load; shown as the cell output.
list(clip.available_models())

In [None]:
device = "cpu"
# Load CLIP ViT-L/14 at 336 px. Only its text tower is used in this notebook;
# `preprocess` (the matching image transform) is kept for completeness.
model, preprocess = clip.load("ViT-L/14@336px", device=device)

# Smoke-test the text encoder on a few prompts. `text_features` is also used later to
# size the CIF encoder's output, so both embeddings share one dimension.
prompts = ["octopussy","cat","A new design strategy for high-performance organic cathode active materials for lithium-ion batteries is presented.X-ray diffraction measurements and sorption experiments demonstrated that the intercolumnar spaces in PCT-1 can incorporate various molecules accompanied by lattice expansion."]
text = clip.tokenize(prompts).to(device)

with torch.no_grad():
    text_features = model.encode_text(text)

In [None]:
# Load the id/property table. It has no header row, and column 0 holds the CIF ids.
df = pd.read_csv(filepath_or_buffer="cifs/id_prop.csv", header=None)

In [None]:
# Keep only the rows whose CIF file actually exists under cifs/.
df = df[
    df[0].map(lambda cif_id: os.path.exists(f"cifs/{cif_id}.cif"))
]

In [None]:
# Overwrite id_prop.csv with the filtered table, keeping the layout it was read with:
# no header and no index.
df.to_csv(path_or_buf="cifs/id_prop.csv", header=None, index=False)

In [None]:
data = CIFData("cif_photocatalyst")
# Run only if loading fails on invalid CIF files: `find_errors` locates them.
# NOTE(review): judging by the flag, write=False may only report; confirm which setting deletes.
#data.find_errors(write=False)

# Size the CGCNN input layers from the first sample. Its output width is matched to
# CLIP's text embedding (`text_features`), so both live in one shared space.
s, _, _ = data[0]
cif_encoder = CrystalGraphConvNet(
    s[0].shape[-1],
    s[1].shape[-1],
    classification=False,
    output_dim=text_features.shape[-1],
    n_h=2,
    n_conv=3,
)

In [None]:
# Split ratios are train / val / test = 1 / 0 / 0: everything is used for training,
# and the validation split is presumably empty (`train` must cope with that).
# NOTE(review): with batch_size=1, `loss_func` sees a 1x1 logit matrix. Its log-softmax is
# 0, so the contrastive loss is identically 0 and gives no training signal.
# TODO: consider batch_size >= 2.
train_loader, val_loader, test_loader = get_train_val_test_loader(
        train_ratio=1,
        val_ratio=0,
        test_ratio=0,
        dataset=data,
        collate_fn=collate_pool,
        batch_size=1,
        return_test=True)

# Optimize BOTH encoders. Previously only `model` (CLIP) was registered, so the CIF
# encoder never changed. Because optimizer.zero_grad() only clears registered
# parameters, its gradients also kept accumulating across batches.
optimizer = optim.Adam(
    list(model.parameters()) + list(cif_encoder.parameters()), lr=0.001)
scheduler = MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

In [None]:
# Restore the CIF-encoder weights from the end-of-epoch checkpoint written by `train`
# for epoch index 1 (zero-based), mapping all tensors to the CPU.
cif_encoder.load_state_dict(
    torch.load(
        "checkpoints/checkpoint_cif_1.pt",
        map_location="cpu",
    )["model_state_dict"]
)

In [None]:
# Restore the CLIP weights from the end-of-epoch checkpoint written by `train` for
# epoch index 1 (zero-based), mapping all tensors to the CPU.
model.load_state_dict(
    torch.load(
        "checkpoints/checkpoint_clip_1.pt",
        map_location="cpu",
    )["model_state_dict"]
)

In [None]:
import importlib
import cgcnn.data
import cgcnn.model

# Pick up edits to the cgcnn sources. Reloading only the package (as before) neither
# re-executes cgcnn/data.py or cgcnn/model.py nor rebinds names pulled in with
# `from ... import`. So reload the submodules first, then the package, and re-import
# the names. Objects created earlier (e.g. `data`, `cif_encoder`) keep the old classes.
importlib.reload(cgcnn.data)
importlib.reload(cgcnn.model)
importlib.reload(cgcnn)
from cgcnn.data import CIFData, collate_pool, get_train_val_test_loader
from cgcnn.model import CrystalGraphConvNet

In [None]:
# Number of batches in the train / val / test loaders.
tuple(len(loader) for loader in (train_loader, val_loader, test_loader))

In [None]:
import gc

# Output folders for training checkpoints and exported weights. exist_ok keeps the cell
# re-runnable. os.mkdir raised FileExistsError on a second run, and a failure on
# "checkpoints" also skipped creating "clamp_weights".
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("clamp_weights", exist_ok=True)

In [None]:
import torch.nn.functional as F

# The training cell below places both models and all batches on the GPU.
device = "cuda"
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between softmax(preds) and a soft target distribution.

    Args:
        preds: logits (at least 2-D); log-softmax is taken over the last dim.
        targets: target probabilities, broadcastable against `preds`.
        reduction: 'none' returns the per-row losses, 'mean' their average,
            'sum' their total.

    Returns:
        A 1-D tensor of per-row losses for 'none', otherwise a scalar tensor.

    Raises:
        ValueError: for an unknown `reduction`. Previously this silently returned None.
    """
    loss = (-targets * preds.log_softmax(dim=-1)).sum(1)
    if reduction == 'none':
        return loss
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    raise ValueError(
        "reduction must be 'none', 'mean' or 'sum', got {!r}".format(reduction))
def loss_func(feat1, feat2):
    """Symmetric CLIP-style contrastive loss with soft targets.

    Args:
        feat1: embeddings from one encoder. In training these are (N, D), row i paired
            with row i of `feat2`, and L2-normalized by the caller. Other shapes
            broadcast inside the target computation.
        feat2: embeddings from the other encoder.

    Returns:
        Unreduced per-row loss: the average of the feat1->feat2 and feat2->feat1
        cross-entropies.

    The targets are not one-hot. They are the softmax of the averaged intra-modal
    similarities, so samples that look alike in either modality share target mass.
    Logits are raw dot products (no temperature).
    """
    logits = feat1 @ feat2.T
    intra_similarity = (feat1 @ feat1.T + feat2 @ feat2.T) / 2
    soft_targets = F.softmax(intra_similarity, dim=-1)
    loss_1to2 = cross_entropy(logits, soft_targets, reduction='none')
    loss_2to1 = cross_entropy(logits.T, soft_targets.T, reduction='none')
    return (loss_1to2 + loss_2to1) / 2.0
def encode_text(targets):
    """Embed a batch of strings with the CLIP text encoder, supporting long texts.

    Each string is tokenized into a multiple of CLIP's 77-token window and cut into
    77-token chunks. Each chunk that actually carries the string's tokens is encoded,
    and the chunk embeddings are averaged.

    Args:
        targets: non-empty sequence of strings, one per batch sample.

    Returns:
        Tensor with one embedding row per string.

    Raises:
        ValueError: if `targets` is empty.

    Uses the notebook globals `model` (CLIP) and `device`.
    """
    if len(targets) == 0:
        raise ValueError("encode_text needs at least one string")
    chunk_len = 77  # CLIP's fixed text context length
    # Size the context from UTF-8 bytes + 2 (start/end tokens). CLIP's byte-level BPE
    # emits at most about one token per byte. The old character count could undercount
    # (the 2 special tokens, non-ASCII text) and make clip.tokenize raise.
    longest = max(len(text.encode("utf-8")) for text in targets) + 2
    n_chunks = -(-longest // chunk_len)  # ceiling division
    tokens = clip.tokenize(targets, context_length=n_chunks * chunk_len).to(device)

    embeddings = []
    for seq in tokens:
        # The end-of-text token has the largest id in CLIP's vocabulary (CLIP's own
        # encode_text locates it with argmax). Only chunks up to and including it hold
        # real tokens. The trailing all-padding chunks used to be averaged in, diluting
        # the embedding of every string shorter than the context length.
        eot_pos = int(seq.argmax())
        n_used = eot_pos // chunk_len + 1
        chunks = seq[: n_used * chunk_len].reshape(n_used, chunk_len)
        ctx = model.encode_text(chunks)
        embeddings.append(torch.mean(ctx, dim=0))
    return torch.stack(embeddings)

# `model` is the CLIP model (only its text encoder is used; the image encoder is idle)
# and `cif_encoder` is the CGCNN crystal encoder. Both go to the GPU in float32 and
# are put in training mode.
model = model.cuda().float()
cif_encoder = cif_encoder.cuda().float()
cif_encoder.train()
model.train()
# Best (lowest) training / validation losses seen so far. `train` compares against
# them before writing its *_least checkpoints.
least_loss = least_val_loss = float('inf')
def _batch_to_cuda(inputs):
    """Move one collated CGCNN batch (the loaders' first element) to the GPU.

    `inputs` is a 4-tuple, presumably (atom features, neighbour features, neighbour
    indices, list of per-crystal atom-index tensors) as in CGCNN. The two feature
    tensors are cast to float32.
    """
    atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx = inputs
    return (atom_fea.cuda(non_blocking=False).float(),
            nbr_fea.cuda(non_blocking=False).float(),
            nbr_fea_idx.cuda(non_blocking=False),
            [idx.cuda(non_blocking=False) for idx in crystal_atom_idx])


def _contrastive_loss(inputs, targets):
    """Embed one batch with both encoders and return the batch-mean contrastive loss.

    Both embeddings are L2-normalized before `loss_func`. Its per-sample losses are
    averaged, so the result is a scalar usable with .backward()/.item() for any batch size.
    """
    cif_embeddings = cif_encoder(*_batch_to_cuda(inputs))
    text_embeddings = encode_text(targets)
    cif_embeddings = cif_embeddings / cif_embeddings.norm(dim=1, keepdim=True)
    text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
    return loss_func(cif_embeddings.float().to(device), text_embeddings.float()).mean()


def _validation_loss():
    """Return the mean batch loss over `val_loader`, or NaN when it has no batches.

    Both encoders are switched to eval mode for the pass (so layers such as dropout or
    batch-norm, if present, do not update from validation data) and back to train mode.
    """
    if len(val_loader) == 0:
        return float('nan')
    model.eval()
    cif_encoder.eval()
    try:
        with torch.no_grad():
            total = sum(_contrastive_loss(inputs, targets).item()
                        for inputs, targets, _ in val_loader)
    finally:
        model.train()
        cif_encoder.train()
    return total / len(val_loader)


def _save_checkpoints(tag, epoch, loss, val_loss):
    """Write checkpoints/checkpoint_clip_<tag>.pt and checkpoints/checkpoint_cif_<tag>.pt."""
    for prefix, module in (('clip', model), ('cif', cif_encoder)):
        torch.save({
            'epoch': epoch,
            'model_state_dict': module.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'val_loss': val_loss,
            }, 'checkpoints/checkpoint_{}_{}.pt'.format(prefix, tag))


def train(epochs):
    """Jointly train the CLIP text encoder (`model`) and the CIF encoder.

    Every 15 batches:
      * the LR scheduler is stepped;
      * the validation loss is measured;
      * "best so far" checkpoints (*_val_least, *_least) are written when the
        validation or training loss beats `least_val_loss` / `least_loss`
        (which are then updated).

    A checkpoint pair named after the epoch index is written at the end of each epoch.

    Args:
        epochs: number of passes over `train_loader`.
    """
    global least_loss, least_val_loss
    loss_value = float('nan')
    val_loss = float('nan')
    n_batches = len(train_loader)
    for epoch in range(epochs):
        for batch_idx, (inputs, targets, _) in enumerate(train_loader):
            optimizer.zero_grad()
            loss = _contrastive_loss(inputs, targets)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx, n_batches,
                    100. * batch_idx / n_batches, loss_value))
            if batch_idx % 15 != 0:
                continue
            # MultiStepLR counts its own steps. The old `scheduler.step(loss)` passed the
            # loss where the (deprecated) epoch index goes.
            scheduler.step()
            val_loss = _validation_loss()
            # NaN (empty validation split) never counts as an improvement.
            if val_loss < least_val_loss:
                least_val_loss = val_loss
                _save_checkpoints('val_least', epoch, loss_value, val_loss)
            if loss_value < least_loss:
                least_loss = loss_value
                _save_checkpoints('least', epoch, loss_value, val_loss)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tVal Loss: {:.6f}'.format(
                epoch, batch_idx, n_batches,
                100. * batch_idx / n_batches, loss_value, val_loss))
        _save_checkpoints(epoch, epoch, loss_value, val_loss)

# Two passes over the training set.
train(epochs=2)

In [None]:
# Export the final weights of both encoders as bare state_dicts (no optimizer state).
torch.save(obj=model.state_dict(), f="clamp_weights/text_encoder.pt")
torch.save(obj=cif_encoder.state_dict(), f="clamp_weights/cif_encoder.pt")

In [None]:
# Scratch check left from development: a [:1]-style slice keeps a list (evaluates to [0]).
[0, 1][0:1]

In [None]:
from tqdm import tqdm  # used here before the later cell that imports it

device="cpu"
model.eval()
cif_encoder.eval()
# Run batches on whichever device the encoder currently lives on: CUDA after the
# training cell, CPU after only loading checkpoints. Store embeddings on the CPU so
# later cells can concatenate, pickle and plot them.
encoder_device = next(cif_encoder.parameters()).device
vectors = {}
with torch.no_grad():
    for batch_idx, (inputs, targets, _) in tqdm(enumerate(train_loader), total=len(train_loader)):
        atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx = inputs
        batch = (atom_fea.to(encoder_device).float(),
                 nbr_fea.to(encoder_device).float(),
                 nbr_fea_idx.to(encoder_device),
                 [idx.to(encoder_device) for idx in crystal_atom_idx])
        # Encode once and store one row slice per text. The previous version ran the
        # encoder twice and assumed exactly two samples per batch: with batch_size=1 it
        # raised IndexError, and with more it lumped extra rows under targets[1].
        embeddings = cif_encoder(*batch).cpu()
        for row, target in enumerate(targets):
            vectors[target] = embeddings[row:row + 1]

In [None]:
import pickle

# Persist the text -> CIF-embedding mapping built above.
with open('text2embedding.pkl', 'wb') as out_file:
    pickle.dump(vectors, out_file, pickle.HIGHEST_PROTOCOL)

In [None]:
# Ad-hoc retrieval check. For every training batch:
#   * stack the first entry of `vectors` (a fixed reference) with the batch's CIF embedding;
#   * score both against one photocatalyst query with `loss_func`;
#   * count how often index 1 has the largest softmax(loss) value.
# NOTE(review): this sets device="cpu" but does not move `cif_encoder`/`model`. After the
# training cell they are on CUDA while these inputs stay on the CPU — confirm.
# NOTE(review): softmax is monotonic, so argmax(probs0) is the *highest-loss* row; confirm
# that "index 1 has the highest loss" is the intended success criterion.
device="cpu"
model.eval()
cif_encoder.eval()
total = 0  # batches where index 1 won
gone = 0   # batches processed
for batch_idx, (inputs, targets, _) in enumerate(train_loader):
    inputs = (Variable(inputs[0].float()),
                 Variable(inputs[1].float()),
                 inputs[2],
                 [crys_idx for crys_idx in inputs[3]])
    with torch.no_grad():
        gone+=1
        cif_embedding0 = cif_encoder(*inputs)
        # Row 0: reference embedding; following rows: the current batch.
        cif_embedding1 = torch.cat([list(vectors.values())[0], cif_embedding0], dim=-2)
        print(cif_embedding1.shape)
        # NOTE(review): the query embedding is loop-invariant and could be computed once.
        text_embeddings = encode_text(["photocatalyst methane adsorption conversion artificial photosynthesis"])
        # NOTE(review): unlike training, the CIF embeddings are not L2-normalized here.
        #cif_embeddings = cif_embeddings / cif_embeddings.norm(dim=1, keepdim=True)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
        loss0 = loss_func(cif_embedding1.float().to(device), text_embeddings.float())
        probs0 = loss0.softmax(dim=-1).cpu().numpy()
        #loss1 = loss_func(text_embeddings.float(), cif_embedding1.float().to(device))
        #probs1 = loss1.softmax(dim=-1).cpu().numpy()
        if np.argmax(probs0) == 1:
            total+=1
        print(probs0)
        print(f"{total}/{gone} {(total/gone)*100}%")

In [None]:
device = "cpu"
# NOTE: `crystals` is filled by the CSD-embedding cell at the end of the notebook;
# run that cell first.
TOP_K = 10
with torch.no_grad():
    text_embeddings = encode_text(["2d flat photocatalyst methane adsorption conversion artificial photosynthesis with visible light"])
    crystal_ids = list(crystals.keys())
    cif_embeddings = torch.cat(list(crystals.values()), dim=-2)
    cif_embeddings = cif_embeddings / cif_embeddings.norm(dim=1, keepdim=True)
    text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
    loss = loss_func(text_embeddings.float(), cif_embeddings.float().to(device))
    probs = loss.softmax(dim=-1).cpu().numpy()
    # The lowest-loss crystal.
    best = np.argmin(probs)
    print(best)
    print(crystal_ids[best])
    # The TOP_K largest softmax(loss) values, in descending order. argpartition only
    # selects them without ordering, so sort the selection. Clamp k for small sets,
    # where argpartition(probs, -10) would raise.
    # NOTE(review): these are the *highest-loss* crystals — confirm that is intended.
    k = min(TOP_K, len(probs))
    idxs = np.argpartition(probs, -k)[-k:]
    idxs = idxs[np.argsort(probs[idxs])[::-1]]
    for idx in idxs:
        print(probs[idx])
        print(crystal_ids[idx])

In [None]:
import numpy as np
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
%matplotlib notebook
from sklearn.manifold import TSNE

pca = TSNE(n_components=2)
reduced = pca.fit_transform(torch.cat(list(vectors.values()), dim=-2))

# We need a 2 x 944 array, not 944 by 2 (all X coordinates in one list)
t = reduced.transpose()
plt.title("Final Latent Space Diagram (77 dimensional)")
for idx in range(len(t[0])):
    colort = "blue"
    label = "non-photocatalyst"
    if "photo" in list(vectors.keys())[idx]:
        colort="red"
        label="photocatalyst"
    plt.scatter(t[0][idx], t[1][idx], color=colort, alpha=0.1)
    
legend_elements = [Line2D([0], [0], marker='o', color='w', label='Photocatalyst',
                          markerfacecolor='red', markersize=10), Line2D([0], [0], marker='o', color='w', label='Non-Catalyst',
                          markerfacecolor='blue', markersize=10)]

plt.legend(handles=legend_elements)
plt.show()

In [None]:
import pickle  # this cell may run on a fresh kernel without the export cell

# Reload the text -> CIF-embedding mapping written by the export cell.
# Security: pickle.load can execute arbitrary code; only load files you produced yourself.
with open("text2embedding.pkl", 'rb') as f:
    vectors = pickle.load(f)

In [None]:
import re

# Collect the per-batch training losses from a saved training log, skipping lines that
# also report a "Val Loss".
# - With re.MULTILINE, `$` also matches a final line without a trailing newline; the
#   old `\n` anchor would drop it.
# - A list (not a lazy filter) can be inspected and iterated more than once.
with open('logs/logs-new-epoch1.txt') as f:
    contents = f.read()
matches = [loss_text
           for loss_text in re.findall(r"Loss: (.*)$", contents, flags=re.MULTILINE)
           if "Val" not in loss_text]

In [None]:
# Parse the captured loss strings into floats.
matches = list(map(float, matches))
def moving_average(x, w):
    """Simple moving average of `x` over windows of `w` consecutive samples.

    Uses 'valid' convolution, so for len(x) >= w the result has len(x) - w + 1 values.
    """
    window_sums = np.convolve(x, np.ones(w), mode='valid')
    return window_sums / w

# Smooth the per-batch *training* losses (validation lines were filtered out when
# parsing) and plot them. The old label "Validation acc" was wrong.
# 'valid' convolution yields len(matches) - window + 1 points, so the x values start at
# window - 1: the batch that completes the first full window.
window = min(50, len(matches))
smoothed = moving_average(matches, window)
plt.plot(range(window - 1, window - 1 + len(smoothed)), smoothed, 'b',
         label='Training loss ({}-batch moving average)'.format(window))

plt.title("Loss Over Batches (2 Epochs)")
plt.xlabel("Iterations")
plt.ylabel("Cosine Similarity Loss")
plt.legend()
plt.show()

In [None]:
from os import listdir
from os.path import isfile, join
# Regular files (no sub-directories) directly inside cif_csd/, in listdir order.
onlyfiles = list(filter(lambda name: isfile(join("cif_csd", name)), listdir("cif_csd")))

In [None]:
# Keep only CIF files. Match the extension rather than any name containing "cif",
# which would also pick up files such as "cif_notes.txt" or "x.cif.bak".
cifs = [name for name in onlyfiles if name.endswith(".cif")]

In [None]:
# Turn file names into CIF ids: keep everything before the first ".cif".
cifs = list(map(lambda name: name.split(".cif")[0], cifs))

In [None]:
# id_prop table for CIFData: each CIF id with a placeholder "text" (just the row number).
# These CSD structures are presumably only embedded, not trained on.
df = pd.DataFrame({"file": cifs, "text": [*range(len(cifs))]})

In [None]:
# Write it where CIFData("cif_csd") expects it: no header row, no index column.
df.to_csv(path_or_buf="cif_csd/id_prop.csv", index=None, header=None)

In [None]:
# Embed a 14% subset of the CSD structures. get_train_val_test_loader chooses the
# subset; everything selected lands in the "train" split, and the val/test ratios are 0.
data = CIFData("cif_csd")
train_loader, val_loader, test_loader = get_train_val_test_loader(
    dataset=data,
    collate_fn=collate_pool,
    batch_size=1,
    train_ratio=0.14,
    val_ratio=0,
    test_ratio=0,
    return_test=True,
)

In [None]:
from tqdm import tqdm
crystals = {}
#data.find_errors(write=True)
cif_encoder.eval()
# Feed the batches to whatever device the encoder lives on (the loader yields CPU
# tensors). Keep the stored embeddings on the CPU for the query/plot cells.
encoder_device = next(cif_encoder.parameters()).device
with torch.no_grad():  # inference only: don't build and retain autograd graphs
    for inputs, targets, names in tqdm(train_loader, total=len(train_loader)):
        atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx = inputs
        batch = (atom_fea.to(encoder_device).float(),
                 nbr_fea.to(encoder_device).float(),
                 nbr_fea_idx.to(encoder_device),
                 [idx.to(encoder_device) for idx in crystal_atom_idx])
        embeddings = cif_encoder(*batch).cpu()
        # One row per crystal, keyed by the id in the loader's third field. The previous
        # version stored the whole batch under the first id only.
        for row, name in enumerate(names):
            crystals[name] = embeddings[row:row + 1]