In [None]:
import yaml
import json
import plotly.express as px
import pandas as pd
import numpy as np
from collections import Counter
from pqdm.processes import pqdm
import seaborn as sns
from pathlib import Path
from pymatgen.core import Structure, PeriodicSite
from pymatgen import vis
from sklearn.model_selection import train_test_split
from glob import glob
from matplotlib import pyplot as plt
import os
import matplotlib.pyplot as plt
from pymatgen.util.coord import pbc_shortest_vectors
from matplotlib.ticker import (AutoMinorLocator, MultipleLocator, LinearLocator, IndexLocator)

%config InlineBackend.figure_format='retina'

def read_pymatgen_dict(file):
    """Load a pymatgen ``Structure`` from a JSON file.

    Parameters
    ----------
    file :
        Path of a JSON file holding the dictionary form of a structure.

    Returns
    -------
    The object produced by ``Structure.from_dict`` for the parsed JSON.
    """
    with open(file, "r") as handle:
        payload = json.load(handle)
    structure = Structure.from_dict(payload)
    return structure

In [None]:
# Parse every file matched in the private structures folder.
structures_private = list(map(read_pymatgen_dict, glob('data/dichalcogenides_private/structures/*')))

In [None]:
# os.listdir returns entries in arbitrary order; sort them so that the row
# order of the data frame built from these lists (and anything indexed by
# position, e.g. structures_public[0]) is the same on every run.
structure_paths = sorted(os.listdir('data/dichalcogenides_public/structures/'))
# structures_public[i] is the parsed structure for the file structure_paths[i].
structures_public = [read_pymatgen_dict(f"data/dichalcogenides_public/structures/{p}") for p in structure_paths]

## Spacegroups

In [None]:
def compute_spacegroups(x):
    """Return the first element of ``get_space_group_info()`` for each structure in ``x``."""
    labels = []
    for structure in x:
        labels.append(structure.get_space_group_info()[0])
    return labels


def compute_composition(x):
    """Return ``str(structure.composition)`` for each structure in ``x``."""
    formulas = []
    for structure in x:
        formulas.append(str(structure.composition))
    return formulas

# Space-group labels for the private and public structure lists.
sg_private, sg_public = map(compute_spacegroups, (structures_private, structures_public))

# Composition strings for the private and public structure lists.
cmp_private, cmp_public = map(compute_composition, (structures_private, structures_public))

In [None]:
# Rank space groups by frequency in the public set. Both panels reuse this
# ordering, so the bars of the two datasets line up category by category.
order = [label for label, _ in Counter(sg_public).most_common()]

plt.figure(figsize=(10, 3))

# Left panel: public set.
plt.subplot(1, 2, 1)
sns.countplot(x=sg_public, order=order)

# Right panel: private set.
plt.subplot(1, 2, 2)
sns.countplot(x=sg_private, order=order)

In [None]:
# Rank compositions by frequency in the public set. Both panels reuse this
# ordering, so the bars of the two datasets line up category by category.
order = [label for label, _ in Counter(cmp_public).most_common()]

plt.figure(figsize=(10, 3))

# Left panel: public set.
plt.subplot(1, 2, 1)
sns.countplot(x=cmp_public, order=order)
plt.xticks(rotation=45, ha='right');

# Right panel: private set.
plt.subplot(1, 2, 2)
sns.countplot(x=cmp_private, order=order)
plt.xticks(rotation=45, ha='right');

## Construct DataFrame

In [None]:
# One row per public structure: `id` is the file name up to its first dot,
# `s` is the parsed structure.
df = pd.DataFrame({
    'id': [file_name.split('.')[0] for file_name in structure_paths],
    's': structures_public,
})

In [None]:
# Skip the first line of the CSV and impose our own column names.
target = pd.read_csv(
    'data/dichalcogenides_public/targets.csv',
    skiprows=1,
    names=['id', 'band_gap'],
)

In [None]:
# Inner join on `id`: only structures that have a target value are kept.
df = pd.merge(df, target, on='id', how='inner')

In [None]:
# Keep the previous index as a regular column (used later as the hover label).
df = df.reset_index(drop=False)

In [None]:
# First element of get_space_group_info() for every structure in the frame.
df['spacegroup'] = df['s'].map(lambda structure: structure.get_space_group_info()[0])

In [None]:
# String form of each structure's composition.
df['composition'] = df['s'].map(lambda structure: str(structure.composition))

In [None]:
# Number of structures per composition.
df.loc[:, 'composition'].value_counts()

In [None]:
# Band gap of every sample, grouped by composition (left) and by space group
# (right). Categories follow the order returned by value_counts().
plt.figure(figsize=(10, 3))

plt.subplot(121)
sns.stripplot(x='composition', y='band_gap', data=df, order=df['composition'].value_counts().index, size=5)
plt.xticks(rotation=45, ha='right');

plt.subplot(122)
# x and y are passed by keyword, like in the left panel: seaborn deprecated
# positional x/y and newer releases reject them.
sns.stripplot(x='spacegroup', y='band_gap', data=df, order=df['spacegroup'].value_counts().index, size=5)
plt.xticks(rotation=45, ha='right');

In [None]:
# Tall figure with tiny markers and a dotted minor grid on the band-gap axis,
# one panel per grouping column.
plt.figure(figsize=(10, 10))

for panel, column in enumerate(('composition', 'spacegroup'), start=1):
    plt.subplot(1, 2, panel)
    sns.stripplot(x=column, y='band_gap', data=df, order=df[column].value_counts().index, size=1)
    plt.xticks(rotation=45, ha='right')

    # Minor y ticks every 0.04, drawn as light dotted grid lines.
    current_axes = plt.gca()
    current_axes.yaxis.set_minor_locator(IndexLocator(base=0.04, offset=0))
    current_axes.grid(which='minor', color='#CCCCCC', linestyle=':')

In [None]:
# Band gap by composition, coloured by space group, with the space groups
# separated side by side inside each composition.
plt.figure(figsize=(15, 6))
# `dodge` is a boolean switch, not a column name; the previous string value
# only worked because a non-empty string is truthy.
sns.stripplot(x='composition', y='band_gap', data=df, order=df['composition'].value_counts().index, size=1, hue='spacegroup', dodge=True)
plt.xticks(rotation=45, ha='right');
# Minor y ticks every 0.04, drawn as light dotted grid lines.
plt.gca().yaxis.set_minor_locator(IndexLocator(base=0.04, offset=0))
plt.gca().grid(which='minor', color='#CCCCCC', linestyle=':')

In [None]:
# Interactive strip plot: band gap per composition, coloured by space group,
# with the `index` column shown as the hover title.
fig = px.strip(
    df,
    x='composition',
    y='band_gap',
    color='spacegroup',
    hover_name='index',
    stripmode='overlay',
)

# Small markers with a thin dark outline.
fig.update_traces(
    marker={
        'size': 3,
        'line': {'width': 0.1, 'color': 'DarkSlateGrey'},
    }
)


In [None]:
# Atomic number -> element symbol. Insertion order is the order in which the
# legend entries are created.
atom_names = {
    16: 'S',
    42: 'Mo',
    74: 'W',
    34: 'Se',
}

# Atomic number -> marker colour (hex) used when drawing that element.
atom_colors = {
    16: '#fcba03',
    74: '#1764ff',
    42: '#9c2496',
    34: '#22ed0c',
}

def plot_legend(ax):
    """Add a two-column legend to ``ax`` with one entry per element in ``atom_names``.

    Each entry comes from an empty scatter call, so only the legend handle is
    created and no point is drawn.
    """
    for atomic_number, symbol in atom_names.items():
        ax.scatter([], [], c=atom_colors[atomic_number], label=symbol, s=125)
    ax.legend(ncol=2)

def plot_layer(ax, coord, atoms, size):
    """Scatter one layer of atoms onto ``ax``, one colour per atomic number.

    Parameters
    ----------
    ax :
        Matplotlib axis to draw on.
    coord :
        Array with one row per atom and two coordinate columns. The columns
        are swapped when plotting, so the second column runs along the
        horizontal axis.
    atoms :
        Array of atomic numbers aligned with the rows of ``coord``. Every
        value must be a key of ``atom_colors``.
    size :
        Marker size passed to ``ax.scatter``.
    """
    for atom_type in np.unique(atoms):
        of_type = (atoms == atom_type)
        # .T[::-1] reverses the coordinate order: (second column, first column).
        ax.scatter(*coord[of_type].T[::-1], c=atom_colors[atom_type], s=size, edgecolors='grey')

    # Hide the axes frame once, outside the loop, so it also happens when the
    # layer contains no atoms.
    ax.axis('off')

    # Faint outline with hard-coded corners -- presumably the boundary of the
    # simulated cell in the swapped coordinates; TODO confirm.
    ax.plot([0, 22, 22, 0, 0], [-0.5, -13.5, 14, 26, -0.5], color='k', linewidth=0.3, alpha=0.3)

def plot_structure(s, mode='mono', ax=None, legend=True):
    """
    Plot the three atomic layers of a structure.

    Atoms are assigned to a layer by comparing their Cartesian z coordinate
    with three fixed z levels; atoms that match none of them are not drawn.

    Parameters
    ----------
    s :
        Pymatgen structure
    mode :
        'mono' to plot all three layers in one axis. In this case the layer
        at the highest z level is drawn with smaller markers, allowing to see
        the bottom layer as well.
        'multi' to plot layers separately in a new three-panel figure.
    ax :
        Matplotlib axis to plot on. Works only in 'mono' mode; ignored in
        'multi' mode. A new figure is created when it is None.
    legend :
        Plot legend or not. The legend is attached to the axis of the last layer.

    Raises
    ------
    ValueError
        If ``mode`` is neither 'mono' nor 'multi'.
    """
    # Fail early with a clear message instead of an unbound-variable error later on.
    if mode not in ('mono', 'multi'):
        raise ValueError(f"mode must be 'mono' or 'multi', got {mode!r}")

    title = str(s.composition)

    if mode == 'multi':
        fig, layer_axes = plt.subplots(1, 3, figsize=(16, 5))
        fig.suptitle(title)
        marker_sizes = [125, 125, 125]
    else:
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=(6, 5))
        ax.set_title(title)
        # All three layers share one axis; the last one is drawn smaller so
        # that it does not hide the atoms underneath.
        layer_axes = [ax, ax, ax]
        marker_sizes = [125, 125, 45]

    # Fixed z coordinates of the three layers, lowest first.
    z_levels = np.array([2.154867, 3.719751, 5.284635])
    cart_coords = s.cart_coords
    atomic_numbers = np.array(s.atomic_numbers)

    for z_level, layer_ax, marker_size in zip(z_levels, layer_axes, marker_sizes):
        in_layer = np.isclose(cart_coords[:, 2], z_level, rtol=1e-04, atol=1e-04)
        plot_layer(layer_ax, cart_coords[in_layer, :2], atomic_numbers[in_layer], size=marker_size)

    if legend:
        plot_legend(layer_axes[-1])

In [None]:
# Draw the first public structures on a 3x3 grid, one per axis. zip stops at
# the shorter input, so exactly as many structures as there are axes are drawn
# and a data set with fewer than nine structures no longer raises IndexError.
fig, axs = plt.subplots(3, 3, figsize=(15, 15))
for ax, structure in zip(axs.flatten(), structures_public):
    plot_structure(structure, ax=ax)