In [1]:
import glob
import itertools
import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import sparse
import torch
from PIL import Image
from tqdm import tqdm

In [2]:
def create_voxel_grid(cif_file, voxel_count, max_dims, atoms_used=('Ce', 'O')):
    '''
    Voxelize a single .cif file. Automatically detects which of the supported .cif layouts is used.

    Inputs:
    cif_file - str, path to .cif file
    voxel_count - int (same count on every axis) or length-3 sequence [x,y,z] of ints;
                  number of voxels per side.
    max_dims - [x,y,z] largest cell lengths across all structures. Each structure's coordinates
               are scaled by (its cell length / max_dims) so that structures of different sizes
               share one grid scale (e.g. the z depth varies across structures).
    atoms_used - collection of atom symbols to count in the grid.

    Outputs:
    grid - sparse.COO array of shape (voxel_count_x, voxel_count_y, voxel_count_z) whose values
           are the number of selected atoms whose scaled coordinate falls in that voxel
    coor - numpy array [n_atoms,3] of scaled coordinates of each atom
    atom_type - numpy array [n_atoms] of atom symbols

    Raises:
    ValueError - if the atom-table layout is not recognised or voxel_count has the wrong length.
    TypeError - if voxel_count is not integer-valued.

    NOTE: coor and atom_type include all atoms, not just those in atoms_used.
    NOTE(review): coordinates in the file are assumed to be fractional (in [0, 1]); values
    outside that range after scaling are clamped into the edge voxels.
    '''
    with open(cif_file) as f:
        lines = f.readlines()

    # Cell lengths a, b, c. The older layout has them starting at line index 2; the newer
    # layout has one extra header line, detected by line index 2 not ending in a number.
    increm = 0
    try:
        x_dim = float(lines[2].split('  ')[-1])
    except ValueError:
        increm = 1
        x_dim = float(lines[2 + increm].split('  ')[-1])
    y_dim = float(lines[3 + increm].split('  ')[-1])
    z_dim = float(lines[4 + increm].split('  ')[-1])

    # Atom table. The layout is identified by how many double-space-separated fields appear
    # on fixed line indices; the last line of the file is never treated as an atom row.
    n_fields_15 = len(lines[15].split("  "))
    if n_fields_15 < 4:
        if len(lines[23].split("  ")) < 4:
            # .cif files from ASU_April_21
            rows = np.array([line.split("  ") for line in lines[26:-1]])
            coor = np.array([[float(v) for v in row] for row in rows[:, 4:7]])
        else:
            # .cif files from models_wedge_cif and models_wedge_cif_2
            rows = np.array([line.split("  ") for line in lines[23:-1]])
            coor = np.array([[float(v) for v in row] for row in rows[:, 2:5]])
        # The symbol is the first space-separated token of the label column.
        atom_type = np.array([label.split(" ")[0] for label in rows[:, 1]])
    elif n_fields_15 == 4:
        # .cif files from models_cif: symbol followed by x, y, z
        rows = np.array([line.split("  ") for line in lines[15:-1]])
        coor = np.array([[float(v) for v in row] for row in rows[:, 1:]])
        atom_type = rows[:, 0]
    else:
        raise ValueError(f"Unrecognised .cif atom-table layout in {cif_file}")

    # Voxels per axis; a scalar applies to all three axes. Accepts list, tuple or ndarray.
    counts = np.broadcast_to(np.asarray(voxel_count), (3,))
    if not np.issubdtype(counts.dtype, np.integer):
        raise TypeError("voxel_count must be an int or a length-3 sequence of ints")
    counts = counts.astype(int)
    grid = np.zeros(tuple(counts))

    # Normalize coor for axis lengths
    coor = coor * np.array([x_dim, y_dim, z_dim]) / np.array(max_dims)

    selected = np.array([t in atoms_used for t in atom_type], dtype=bool)
    idx = np.floor(coor[selected] * counts).astype(int)
    # Clamp into the grid: a coordinate of exactly 1.0 would otherwise index one past the end
    # (IndexError), and a negative one would silently wrap to the far side of the axis.
    idx = np.clip(idx, 0, counts - 1)
    # np.add.at accumulates correctly when several atoms land in the same voxel.
    np.add.at(grid, tuple(idx.T), 1)

    grid = sparse.COO.from_numpy(grid)
    return grid, coor, atom_type

def max_z_value(dir_header, dir_list_cif):
    '''
    Return the largest cell lengths [a, b, c] across every 'Ce*.cif' file in the given directories.

    Despite the name, all three axes are returned; the result is intended as the `max_dims`
    scaling for voxelization so every structure shares one grid scale.

    Inputs:
    dir_header - str, path prefix prepended to each directory (paths are built by plain
                 concatenation, so it should end with '/')
    dir_list_cif - list of directory names relative to dir_header

    Outputs:
    list [max_a, max_b, max_c]

    Raises:
    ValueError - if no matching .cif files are found.
    '''
    cif_files = []
    for dirr in dir_list_cif:
        cif_files.extend(glob.glob(dir_header + dirr + '/Ce*.cif'))
    if not cif_files:
        raise ValueError(f"No 'Ce*.cif' files found under {dir_header!r} in {dir_list_cif}")

    x_list, y_list, z_list = [], [], []
    for cf in cif_files:
        with open(cf) as f:
            lines = f.readlines()
        # The newer .cif layout has one extra header line before the cell lengths; it is
        # detected by the value on line index 2 not parsing as a float.
        increm = 0
        try:
            x_list.append(float(lines[2].split('  ')[-1]))
        except ValueError:
            increm = 1
            x_list.append(float(lines[2 + increm].split('  ')[-1]))
        y_list.append(float(lines[3 + increm].split('  ')[-1]))
        z_list.append(float(lines[4 + increm].split('  ')[-1]))

    return [np.max(x_list), np.max(y_list), np.max(z_list)]

In [17]:
def _find_image_files(cif, dir_header, dir_list_img):
    '''Return .tif paths in dir_list_img whose names start with the stem of `cif`, excluding '_def_' files.'''
    stem = os.path.splitext(os.path.basename(cif))[0]
    img_files = []
    for dirr in dir_list_img:
        # '[' and ']' are glob character-class syntax; replace them with the single-character
        # wildcard '?' so bracketed file names still match.
        pattern = (dir_header + dirr + '/' + stem + '*.tif').replace('[', '?').replace(']', '?')
        img_files.extend(glob.glob(pattern))
    return [im_f for im_f in img_files if '_def_' not in im_f]


def _parse_defocus(im_f):
    '''Return the absolute defocus value written just before 'nmDefocus' in an image file name.'''
    loc = im_f.find("nmDefocus")
    if loc < 0:
        raise ValueError(f"image file name has no 'nmDefocus' tag: {im_f}")
    # The value occupies the one or two characters before the tag: a '_' two characters back
    # means a single character (e.g. '_4nmDefocus'), otherwise two (e.g. '12nmDefocus', '-4nmDefocus').
    if im_f[loc-2:loc-1] == '_':
        return abs(int(im_f[loc-1:loc]))
    return abs(int(im_f[loc-2:loc]))


def generate_training_data(dir_header, dir_list_cif, dir_list_img, voxel_count,
                           atoms_used, defocus_used, im_size=256):
    '''
    Generates pairs of voxel grid (X) and image (y) as well as defocus parameter and filenames,
    for training a 3D grid -> image model.

    Inputs:
    dir_header - str, path prefix prepended to every directory below (should end with '/')
    dir_list_cif - list of directories containing .cif files, relative to dir_header
    dir_list_img - list of directories containing .tif files, relative to dir_header
    voxel_count - int or [x,y,z] voxels per grid dimension (passed to create_voxel_grid)
    atoms_used - list of atom symbols (e.g. ['O', 'Ce']) to include in the grid
    defocus_used - collection of int defocus values to keep, or 1 to keep every value
    im_size - int (square) or (width, height) passed to PIL's Image.resize

    Outputs:
    X_list - list of voxel grids as returned by create_voxel_grid, one entry per image; the
             same grid object is repeated for every image of a structure.
    y_list - list of images, one per entry in X_list, resized (not cropped) to im_size.
             NOTE: PIL sizes are (width, height), so each array is (im_size[1], im_size[0])
             for single-band images.
    defocus_list - list of ints, defocus parameter for each sample
    img_file_list - list of image file paths for each sample
    '''
    if np.ndim(im_size) == 0:
        im_size = (im_size, im_size)
    width, height = im_size[0], im_size[1]

    # defocus_used == 1 is a sentinel for "keep every value". Evaluate it once instead of
    # overwriting defocus_used, which previously froze the allowed set to the defocus values
    # of the first .cif file and silently dropped other values for later structures.
    use_all_defocus = isinstance(defocus_used, (int, float)) and defocus_used == 1

    # Largest cell lengths across all structures, so every grid shares one scale.
    max_dims = max_z_value(dir_header, dir_list_cif)

    cif_files = []
    for dirr in dir_list_cif:
        cif_files.extend(glob.glob(dir_header + dirr + '/Ce*.cif'))

    X_list = []
    y_list = []
    defocus_list = []
    img_file_list = []

    for cif in cif_files:
        X, _, _ = create_voxel_grid(cif, voxel_count, max_dims, atoms_used=atoms_used)

        # Number of images kept for this structure, i.e. how often its grid is repeated.
        X_count = 0
        for image in _find_image_files(cif, dir_header, dir_list_img):
            defoc = _parse_defocus(image)
            if not (use_all_defocus or defoc in defocus_used):
                continue
            with Image.open(image) as img:
                im_data = np.array(img.resize((width, height)))
            X_count += 1
            y_list.append(im_data)
            defocus_list.append(defoc)
            img_file_list.append(image)

        X_list.extend(itertools.repeat(X, X_count))
    return X_list, y_list, defocus_list, img_file_list

In [26]:
# Inputs
# Alternative inputs for the models_cif / models_wedge_cif datasets; swap in to use them:
# dir_header = '../../em_data/'
# dir_list_cif = ['models_cif','models_wedge_cif','models_wedge_cif_2']
# dir_list_img = ['all_images','all_images_wedge','all_images_wedge_2']

# Dataset root. Paths are built by plain string concatenation, so it must end with '/'.
dir_header = '../../em_data/ASU_April_21/'
dir_list_cif = ['10at/CIF']  # searched for 'Ce*.cif' structure files
dir_list_img = ['10at']      # searched for the .tif images matching each .cif

voxel_count = [84,54,98]  # voxels along x, y, z
im_size = [84,54]         # passed to PIL resize, i.e. (width, height)
atoms_used=['Ce','O']     # atom symbols counted in the voxel grid
defocus_used = 1 # list defocus values to use, or set =1 to use all

In [27]:
# Build the dataset: one (voxel grid, image) pair per image, plus its defocus and file path.
X_list, y_list, defocus_list, img_file_list = generate_training_data(
    dir_header, dir_list_cif, dir_list_img, voxel_count, atoms_used, defocus_used, im_size)

# Sanity checks: all outputs describe the same samples, and at least one pair was found.
assert len(X_list) == len(y_list) == len(defocus_list) == len(img_file_list)
assert len(X_list) > 0, "no .cif/.tif pairs found - check dir_header and the directory lists"

In [30]:
# Persist the generated dataset so the generation step need not be re-run.
# NOTE: only unpickle these files from trusted sources (pickle can execute arbitrary code).
OUTPUT_SUFFIX = "10a"
outputs = {
    "X_list": X_list,
    "y_list": y_list,
    "defocus_list": defocus_list,
    "img_file_list": img_file_list,
}
for name, obj in outputs.items():
    with open(f"{name}_{OUTPUT_SUFFIX}.pkl", "wb") as fp:
        pickle.dump(obj, fp)

In [None]:
def last_filled_slice(grid):
    '''
    Return the index of the last slice along the third axis whose total is positive.

    grid - 3D array-like supporting .sum(axis=...), e.g. a numpy array or sparse.COO
           (assumes a 3-D grid so two reductions over axis 0 leave the third axis).
    Returns an int index.
    Raises ValueError if no slice has a positive sum.
    '''
    # Sum the x-y slices
    slice_sums = grid.sum(axis=0).sum(axis=0)
    # sparse.COO does not convert implicitly to a dense ndarray; densify the 1-D sums explicitly.
    if hasattr(slice_sums, "todense"):
        slice_sums = slice_sums.todense()
    filled = np.flatnonzero(np.asarray(slice_sums) > 0)
    if filled.size == 0:
        raise ValueError("grid has no slice with a positive sum")
    return int(filled[-1])