In [None]:
import os

import atomai as aoi
import matplotlib.pyplot as plt
import numpy as np
import torch
from atomai.utils import graphx
from mpcontribs.client import Client, Attachments
%matplotlib inline

In [None]:
# Directory holding the input image (Gr_SiCr.npy) and trained model (G_MD.tar).
# Set the PYCROSCOPY_DATA_DIR environment variable to point elsewhere instead of
# editing this hard-coded local path.
data_dir = os.environ.get(
    "PYCROSCOPY_DATA_DIR",
    "/Users/patrick/GoogleDriveLBNL/My Drive/MaterialsProject/gitrepos/mpcontribs-data/pycroscopy",
)
imgdata_path = f"{data_dir}/Gr_SiCr.npy"
imgdata = np.load(imgdata_path)
model_path = f"{data_dir}/G_MD.tar"
model = aoi.load_model(model_path)
# The same checkpoint is also loaded as a raw dict so its weights can be
# converted to lists later on.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_dict = torch.load(model_path, map_location=device)

In [None]:
# Figure size in inches, shared by the plotting cells below.
figsize = 8, 8
# TODO: optionally render the raw image and upload it as an attachment, e.g.
#   fig_raw, ax_raw = plt.subplots(figsize=figsize)
#   ax_raw.imshow(imgdata, cmap="gray")
#   fig_raw.savefig(imgdata_path.replace(".npy", ".png"), bbox_inches='tight')

In [None]:
# Run the trained model on the image, then group atoms belonging to rings of
# the listed cycle lengths into clusters and overlay them on the image.
nn_out, coords = model.predict(imgdata)
# To predict on a resized image instead:
#   model.predict(imgdata, resize=(new_height, new_width))

map_dict = {0: "C", 1: "Si"}  # classes to chemical elements
px2ang = 0.104  # pixel-to-angstrom conversion
coord = coords[0]  # take the first (and the only one) frame
clusters = graphx.find_cycle_clusters(
    coord, cycles=[5, 7], map_dict=map_dict, px2ang=px2ang
)

fig, ax = plt.subplots(1, 1, figsize=figsize)
ax.imshow(imgdata, cmap='gray', origin='lower')
for i, cl in enumerate(clusters):
    # Column 1 is plotted on the x axis and column 0 on the y axis.
    ax.scatter(cl[:, 1], cl[:, 0], s=2, color='red')
    # Label each cluster (1-based) at the integer mean of its points.
    xt = int(np.mean(cl[:, 1]))
    yt = int(np.mean(cl[:, 0]))
    ax.annotate(str(i + 1), (xt, yt), size=10, color='white')

img_path_clusters = imgdata_path.replace(".npy", "_clusters.png")
# Save this cell's figure explicitly rather than pyplot's "current" figure.
fig.savefig(img_path_clusters, bbox_inches='tight')

In [None]:
# Append one all-zero column to every cluster (class 0 == "C" in map_dict).
# New arrays are built so `clusters` itself is never mutated; re-running this
# cell therefore gives the same result instead of adding another column each time.
pad_ = 1
clusters_mod = [
    # ((0, 0), (0, pad_)) pads columns only; a bare (0, pad_) would also add a row.
    np.pad(cl, ((0, 0), (0, pad_)), 'constant')
    for cl in clusters
]

# We can also save all the defects per image frame; here a single defect
# cluster is stored under frame index 0 and plotted on its own.
defect_num = 15
if defect_num >= len(clusters_mod):
    raise IndexError(
        f"defect_num={defect_num} but only {len(clusters_mod)} clusters were found"
    )
coords_def_15 = {0: clusters_mod[defect_num]}

# Use a dedicated figure so the scatter never lands on another open figure.
fig_def, ax_def = plt.subplots()
ax_def.scatter(coords_def_15[0][:, 1], coords_def_15[0][:, 0])

img_path_defects = imgdata_path.replace(".npy", "_defects.png")
fig_def.savefig(img_path_defects, bbox_inches='tight')

In [None]:
# One-time project setup, kept for reference; uncomment only if the project does not exist yet.
# client = Client()
# client.create_project(
#     name="pycroscopy",
#     title="PyCroscopy",
#     authors="A. Ghosh, S. Kalinin",
#     description="Scientific Analysis of nanoscience Data",
#     url="https://pycroscopy.github.io/pycroscopy/about.html"
# )

In [None]:
# MPContribs client bound to the "pycroscopy" project; later cells act through it.
client_options = dict(project="pycroscopy")
client = Client(**client_options)

In [None]:
# Plain nested-list copies of the raw image and the model weights, presumably
# so they can be serialized (e.g. to JSON) for an optional upload.
imgdata_list = imgdata.tolist()  # ndarray.tolist() already returns a new list
# Convert only values that still expose .tolist(). Weights already converted by
# an earlier run of this cell are kept as-is, so re-running does not raise.
model_dict["weights"] = {
    k: v.tolist() if hasattr(v, "tolist") else v
    for k, v in model_dict["weights"].items()
}

In [None]:
# A single contribution for the CrSi entry: the cluster count plus the two saved
# figures. The raw image (imgdata_list) and model weights (model_dict) are
# not attached for now.
attachment_paths = [img_path_clusters, img_path_defects]
contribution = {
    "identifier": "mp-7576",  # CrSi on MP
    "data": {"clusters": len(clusters)},
    "attachments": Attachments.from_list(attachment_paths),
}
contributions = [contribution]

In [None]:
# Replace rather than append: remove the contributions already stored, then
# upload the freshly built list. The order of these two calls matters.
# NOTE(review): delete_contributions() is called without any filter; presumably
# this clears every contribution in the "pycroscopy" project. Confirm before
# running against shared data.
client.delete_contributions()
client.submit_contributions(contributions)

In [None]:
# Publish the project the client is bound to ("pycroscopy").
# NOTE(review): presumably this also makes the submitted contributions publicly
# visible. Confirm against the MPContribs client documentation.
client.make_public()