In [87]:
import numpy as np
import cctbx
import torch
import pandas as pd
import scipy
import iotbx.cif
cif_file = 'BrCl.cif'
reader = iotbx.cif.reader(file_path=cif_file)#.build_crystal_structures()

In [13]:
import numpy as np
import torch
import torch.nn as nn
from iotbx import cif
from models.model import MiniUnet
from models.xrd_transformer import XRDTransformer
from torchmetrics import MeanSquaredError
from pytorch_msssim import ssim

# Configuration
model_type = 'transformer'  # Options: 'unet', 'transformer'

# Load CIF file
cif_file = 'BrCl.cif'
reader = cif.reader(file_path=cif_file)
strucs = reader.build_crystal_structures()

# Load model based on configuration
if model_type == 'unet':
    model = MiniUnet()
    state_dict = {}
    state_old = torch.load("Experiments/unet_3layers/14.09_clin_doob/model.ckpt")['state_dict']
    for key in state_old.keys():
        key_new = key[6:]
        state_dict[key_new] = state_old[key]
    model.load_state_dict(state_dict, strict=True)
    print('UNet model loaded successfully')
elif model_type == 'transformer':
    model = XRDTransformer(
        input_shape=(26, 18, 23),  # Adjust based on your data
        embed_dim=128,
        depth=5,
        num_heads=4,
        mlp_ratio=4,
        drop_rate=0.1,
        attn_drop_rate=0.1,
        embedding_type='onehot'
    )
    state_dict = {}
    state_old = torch.load("Experiments/XRDTransformer/try_75k/model.ckpt")['state_dict']
    for key in state_old.keys():
        key_new = key[6:]
        state_dict[key_new] = state_old[key]
    model.load_state_dict(state_dict, strict=True)
    print('Transformer model loaded successfully')

# Define Laue types for monoclinic structures
laue_types = {'clin': {'h': [-13, 12], 'k': [0, 17], 'l': [0, 22]}}

# Initialize metrics
mse_list = []
r_factor_list = []
ssim_list = []
print(len(strucs))
# Process each structure
for key, struc in strucs.items():
    # Check if the structure is monoclinic
    #print(struc.crystal_symmetry().space_group().crystal_system())
    if struc.crystal_symmetry().space_group().crystal_system() == 'Monoclinic':
        print(f"Processing monoclinic structure: {key}")

        # Calculate structure factors
        a_high = struc.structure_factors(d_min=0.8).f_calc().sort()
        I_high = a_high.as_intensity_array().data().as_numpy_array()
        a_low = struc.structure_factors(d_min=1.5).f_calc().sort()

        # Get indices
        ind_high = a_high.indices()
        ind_low = a_low.indices()
        intensity = I_high

        # Create index mappings
        hkl_minmax = laue_types['clin']
        dics = {'h': {}, 'k': {}, 'l': {}}
        dics__ = {'h': {}, 'k': {}, 'l': {}}
        for letter in 'hkl':
            for i in range(hkl_minmax[letter][1] - hkl_minmax[letter][0] + 1):
                dics[letter][hkl_minmax[letter][0] + i] = i
                dics__[letter][i] = hkl_minmax[letter][0] + i
        h2ind, k2ind, l2ind = dics['h'], dics['k'], dics['l']
        ind2h, ind2k, ind2l = dics__['h'], dics__['k'], dics__['l']

        # Create low and high resolution tensors
        num = 0
        size = [1, 1]
        for letter in 'hkl':
            size.append(hkl_minmax[letter][1] - hkl_minmax[letter][0] + 1)
        low, high = np.zeros(size), np.zeros(size)
        for j, ind in enumerate(ind_high):
            h, k, l = ind
            if h not in h2ind.keys() or k not in k2ind.keys() or l not in l2ind.keys():
                num += 1
            else:
                high[0, 0, h2ind[h], k2ind[k], l2ind[l]] = intensity[j]
        for ind in ind_low:
            h, k, l = ind
            if h in h2ind.keys() and k in k2ind.keys() and l in l2ind.keys():
                low[0, 0, h2ind[h], k2ind[k], l2ind[l]] = high[0, 0, h2ind[h], k2ind[k], l2ind[l]]
        print(f'{num} / {len(ind_high)} didnt fit')

        # Normalize and convert to tensors
        low, high = np.sqrt(low), np.sqrt(high)
        factor = low.max()
        high /= factor
        low /= factor
        low = torch.from_numpy(low).float()
        high = torch.from_numpy(high).float()

        # Reconstruct using the model
        recon = nn.ReLU()(model(low)).detach()

        # Apply mask and compare
        mask = low == 0
        recon = recon * mask + ~mask*low

        # Calculate Mean Squared Error
        criterion = MeanSquaredError()
        mse = criterion(recon, high).detach().item()
        mse_list.append(mse)

        # Calculate R-factor
        r_factor = torch.mean(torch.sum(torch.abs(torch.abs(recon)-torch.abs(high)), axis = (1, 2, 3, 4))/torch.sum(torch.abs(high), axis = (1, 2, 3, 4)))
        r_factor_list.append(r_factor.item())

        # Calculate SSIM 
        recon_np = recon.numpy()
        high_np = high.numpy()
        ssim_value = ssim(recon, high, data_range=high.max() - high.min(), size_average=True)
        ssim_list.append(ssim_value.item())

        print(f'MSE for {key}: {mse}, R-factor: {r_factor.item()}, SSIM: {ssim_value.item()}')

# Calculate mean metrics
mean_mse = np.mean(mse_list)
mean_r_factor = np.mean(r_factor_list)
mean_ssim = np.mean(ssim_list)

print(f'Mean MSE: {mean_mse}, Mean R-factor: {mean_r_factor}, Mean SSIM: {mean_ssim}')
print(mean_mse.shape)

Transformer model loaded successfully
2001
Processing monoclinic structure: CSD_CIF_ABAWIL
105 / 1549 didnt fit




MSE for CSD_CIF_ABAWIL: 0.006347080692648888, R-factor: 0.7872632741928101, SSIM: 0.7260646820068359
Processing monoclinic structure: CSD_CIF_ABICIB
22 / 2346 didnt fit
MSE for CSD_CIF_ABICIB: 0.0012669195421040058, R-factor: 0.7783581018447876, SSIM: 0.6638015508651733
Processing monoclinic structure: CSD_CIF_ABICOH
33 / 2433 didnt fit
MSE for CSD_CIF_ABICOH: 0.002825310220941901, R-factor: 0.7574260234832764, SSIM: 0.6071116924285889
Processing monoclinic structure: CSD_CIF_ABOWAR
482 / 1542 didnt fit
MSE for CSD_CIF_ABOWAR: 0.004000211134552956, R-factor: 0.6627606153488159, SSIM: 0.7124351263046265
Processing monoclinic structure: CSD_CIF_ABRBPH
609 / 2340 didnt fit
MSE for CSD_CIF_ABRBPH: 0.002752214903011918, R-factor: 0.8415480256080627, SSIM: 0.41940703988075256
Processing monoclinic structure: CSD_CIF_ACEKIC
734 / 4188 didnt fit
MSE for CSD_CIF_ACEKIC: 0.00458368519321084, R-factor: 0.810329020023346, SSIM: 0.28259018063545227
Processing monoclinic structure: CSD_CIF_ACEVUA
94

In [15]:
len(mse_list)

1182

In [16]:
import numpy as np
import torch
import torch.nn as nn
from iotbx import cif
from models.model import MiniUnet
from models.xrd_transformer import XRDTransformer
from torchmetrics import MeanSquaredError
from pytorch_msssim import ssim

# Configuration
model_type = 'transformer'  # Options: 'unet', 'transformer'

# Load CIF file
cif_file = 'BrCl.cif'
reader = cif.reader(file_path=cif_file)
strucs = reader.build_crystal_structures()

# Load model based on configuration
if model_type == 'unet':
    model = MiniUnet()
    state_dict = {}
    state_old = torch.load("Experiments/unet_3layers/14.09_clin_doob/model.ckpt")['state_dict']
    for key in state_old.keys():
        key_new = key[6:]
        state_dict[key_new] = state_old[key]
    model.load_state_dict(state_dict, strict=True)
    print('UNet model loaded successfully')
elif model_type == 'transformer':
    model = XRDTransformer(
        input_shape=(26, 18, 23),  # Adjust based on your data
        embed_dim=128,
        depth=5,
        num_heads=4,
        mlp_ratio=4,
        drop_rate=0.1,
        attn_drop_rate=0.1,
        embedding_type='onehot'
    )
    state_dict = {}
    state_old = torch.load("Experiments/XRDTransformer/try_75k/model.ckpt")['state_dict']
    for key in state_old.keys():
        key_new = key[6:]
        state_dict[key_new] = state_old[key]
    model.load_state_dict(state_dict, strict=True)
    print('Transformer model loaded successfully')

# Define Laue types for monoclinic structures
laue_types = {'clin': {'h': [-13, 12], 'k': [0, 17], 'l': [0, 22]}}

# Initialize metrics
mse_list = []
r_factor_list = []
ssim_list = []
print(len(strucs))
# Process each structure
for key, struc in strucs.items():
    # Check if the structure is monoclinic
    #print(struc.crystal_symmetry().space_group().crystal_system())
    if struc.crystal_symmetry().space_group().crystal_system() == 'Monoclinic':
        print(f"Processing monoclinic structure: {key}")

        # Calculate structure factors
        a_high = struc.structure_factors(d_min=0.8).f_calc().sort()
        I_high = a_high.as_intensity_array().data().as_numpy_array()
        a_low = struc.structure_factors(d_min=1.5).f_calc().sort()

        # Get indices
        ind_high = a_high.indices()
        ind_low = a_low.indices()
        intensity = I_high

        # Create index mappings
        hkl_minmax = laue_types['clin']
        dics = {'h': {}, 'k': {}, 'l': {}}
        dics__ = {'h': {}, 'k': {}, 'l': {}}
        for letter in 'hkl':
            for i in range(hkl_minmax[letter][1] - hkl_minmax[letter][0] + 1):
                dics[letter][hkl_minmax[letter][0] + i] = i
                dics__[letter][i] = hkl_minmax[letter][0] + i
        h2ind, k2ind, l2ind = dics['h'], dics['k'], dics['l']
        ind2h, ind2k, ind2l = dics__['h'], dics__['k'], dics__['l']

        # Create low and high resolution tensors
        num = 0
        size = [1, 1]
        for letter in 'hkl':
            size.append(hkl_minmax[letter][1] - hkl_minmax[letter][0] + 1)
        low, high = np.zeros(size), np.zeros(size)
        for j, ind in enumerate(ind_high):
            h, k, l = ind
            if h not in h2ind.keys() or k not in k2ind.keys() or l not in l2ind.keys():
                num += 1
            else:
                high[0, 0, h2ind[h], k2ind[k], l2ind[l]] = intensity[j]
        for ind in ind_low:
            h, k, l = ind
            if h in h2ind.keys() and k in k2ind.keys() and l in l2ind.keys():
                low[0, 0, h2ind[h], k2ind[k], l2ind[l]] = high[0, 0, h2ind[h], k2ind[k], l2ind[l]]
        print(f'{num} / {len(ind_high)} didnt fit')

        # Normalize and convert to tensors
        low, high = np.sqrt(low), np.sqrt(high)
        factor = low.max()
        high /= factor
        low /= factor
        low = torch.from_numpy(low).float()
        high = torch.from_numpy(high).float()

        # Reconstruct using the model
        recon = nn.ReLU()(model(low)).detach()

        # Apply mask and compare
        mask = low == 0
        recon = recon * mask + low#~mask*low

        # Calculate Mean Squared Error
        criterion = MeanSquaredError()
        mse = criterion(recon, high).detach().item()
        mse_list.append(mse)

        # Calculate R-factor
        r_factor = torch.mean(torch.sum(torch.abs(torch.abs(recon)-torch.abs(high)), axis = (1, 2, 3, 4))/torch.sum(torch.abs(high), axis = (1, 2, 3, 4)))
        r_factor_list.append(r_factor.item())

        # Calculate SSIM 
        recon_np = recon.numpy()
        high_np = high.numpy()
        ssim_value = ssim(recon, high, data_range=high.max() - high.min(), size_average=True)
        ssim_list.append(ssim_value.item())

        print(f'MSE for {key}: {mse}, R-factor: {r_factor.item()}, SSIM: {ssim_value.item()}')

# Calculate mean metrics
mean_mse = np.mean(mse_list)
mean_r_factor = np.mean(r_factor_list)
mean_ssim = np.mean(ssim_list)

print(f'Mean MSE: {mean_mse}, Mean R-factor: {mean_r_factor}, Mean SSIM: {mean_ssim}')
print(len(mse_list))

Transformer model loaded successfully
2001
Processing monoclinic structure: CSD_CIF_ABAWIL
105 / 1549 didnt fit




MSE for CSD_CIF_ABAWIL: 0.006368920672684908, R-factor: 0.7781714200973511, SSIM: 0.7316609621047974
Processing monoclinic structure: CSD_CIF_ABICIB
22 / 2346 didnt fit
MSE for CSD_CIF_ABICIB: 0.0012608233373612165, R-factor: 0.7797011733055115, SSIM: 0.6646088361740112
Processing monoclinic structure: CSD_CIF_ABICOH
33 / 2433 didnt fit
MSE for CSD_CIF_ABICOH: 0.0028290869668126106, R-factor: 0.7551823854446411, SSIM: 0.6081129908561707
Processing monoclinic structure: CSD_CIF_ABOWAR
482 / 1542 didnt fit
MSE for CSD_CIF_ABOWAR: 0.003980741370469332, R-factor: 0.6627542972564697, SSIM: 0.715944230556488
Processing monoclinic structure: CSD_CIF_ABRBPH
609 / 2340 didnt fit
MSE for CSD_CIF_ABRBPH: 0.0027609553653746843, R-factor: 0.8284634947776794, SSIM: 0.41374966502189636
Processing monoclinic structure: CSD_CIF_ACEKIC
734 / 4188 didnt fit
MSE for CSD_CIF_ACEKIC: 0.00451062573119998, R-factor: 0.8053115606307983, SSIM: 0.2900439500808716
Processing monoclinic structure: CSD_CIF_ACEVUA
9

In [6]:
torch.utils.checkpoint.checkpoint(torch.ones(1))



TypeError: 'Tensor' object is not callable

In [88]:
strucs = reader.build_crystal_structures()

In [89]:
strucs.keys()

odict_keys(['CSD_CIF_ABAFEQ', 'CSD_CIF_ABAWIL', 'CSD_CIF_ABENAZ', 'CSD_CIF_ABENED', 'CSD_CIF_ABICIB', 'CSD_CIF_ABICOH', 'CSD_CIF_ABORUG', 'CSD_CIF_ABOSAN', 'CSD_CIF_ABOWAR', 'CSD_CIF_ABRBPH', 'CSD_CIF_ACBTHO', 'CSD_CIF_ACEKIC', 'CSD_CIF_ACEVUA', 'CSD_CIF_ACEZIS', 'CSD_CIF_ACOKAE', 'CSD_CIF_ACOKEI', 'CSD_CIF_ACOPUF', 'CSD_CIF_ACOQAM', 'CSD_CIF_ADALUN', 'CSD_CIF_ADAPAZ', 'CSD_CIF_ADAPED', 'CSD_CIF_ADEKAW01', 'CSD_CIF_ADIMEH', 'CSD_CIF_AFEQIM', 'CSD_CIF_AFETIP', 'CSD_CIF_AFOFOS', 'CSD_CIF_AFOTIA', 'CSD_CIF_AFOTOG', 'CSD_CIF_AFUKOD', 'CSD_CIF_AFUVUU', 'CSD_CIF_AGARAD', 'CSD_CIF_AGATIN', 'CSD_CIF_AGATIN01', 'CSD_CIF_AHIFOP', 'CSD_CIF_AJETOZ', 'CSD_CIF_AJEZOF', 'CSD_CIF_AJITOD', 'CSD_CIF_AJITUJ', 'CSD_CIF_AJIWIC', 'CSD_CIF_AJIWOI', 'CSD_CIF_AJUMAV', 'CSD_CIF_AJUMEZ', 'CSD_CIF_AJUNAU', 'CSD_CIF_AJUQEE', 'CSD_CIF_AJUQII', 'CSD_CIF_AKENOU', 'CSD_CIF_AMATAK', 'CSD_CIF_AMCLPY11', 'CSD_CIF_AMILIU', 'CSD_CIF_AMILOA', 'CSD_CIF_ANADAX', 'CSD_CIF_ANADEB', 'CSD_CIF_ANIPOE', 'CSD_CIF_APAJOS', 'CSD_CIF_A

In [90]:
struc = strucs['CSD_CIF_ZUSLIL']

In [91]:
struc.show_summary()

Number of scatterers: 25
At special positions: 0
Unit cell: (12.5894, 9.5051, 9.7495, 90, 111.933, 90)
Space group: P 1 21/c 1 (No. 14)


  xray.structure(
    crystal_symmetry=crystal.symmetry(
      unit_cell=(12.5894, 9.5051, 9.7495, 90, 111.933, 90),
      space_group_symbol="P 1 21/c 1"
    ),
    scatterers=flex.xray_scatterer([
      xray.scatterer( #0
        label="Cl1",
        site=(0.029350, 0.270840, 0.040160),
        u=0.000000),
      xray.scatterer( #1
        label="N1",
        site=(0.503200, 0.261000, 0.586500),
        u=0.000000),
      xray.scatterer( #2
        label="N2",
        site=(0.480070, 0.181100, 0.362700),
        u=0.000000),
      xray.scatterer( #3
        label="C1",
        site=(0.434300, 0.255100, 0.441200),
        u=0.000000),
      xray.scatterer( #4
        label="C2",
        site=(0.584400, 0.135100, 0.463900),
        u=0.000000),
      xray.scatterer( #5
        label="C3",
        site=(0.600000, 0.184700, 0.604500),
        u=0.000000),
      xray.scatterer( #6
        label="C4",
        site=(0.239700, 0.274600, 0.253300),
        u=0.000000),
      xray.scatterer( #

In [92]:
a_high = struc.structure_factors(d_min= 0.8).f_calc().sort()
I_high = a_high.as_intensity_array().data().as_numpy_array()
a_low = struc.structure_factors(d_min= 1.5).f_calc().sort()

In [93]:
ind_high = a_high.indices()
ind_low = a_low.indices()
intensity = I_high

In [94]:
laue_types = {'romb': {'h': [0, 16], 'k': [0, 21], 'l': [0, 28]}, 'clin': {'h': [-13, 12], 'k': [0, 17], 'l': [0, 22]}, 'all': {'h': [-16, 16], 'k': [-14, 21], 'l': [0, 28]}}
hkl_minmax = laue_types['clin']
dics = {'h': {}, 'k': {}, 'l': {}}
dics__ = {'h': {}, 'k': {}, 'l': {}}
for letter in 'hkl':
    for i in range(hkl_minmax[letter][1] - hkl_minmax[letter][0] + 1):
        dics[letter][hkl_minmax[letter][0] + i] = i
        dics__[letter][i] = hkl_minmax[letter][0] + i
h2ind, k2ind, l2ind = dics['h'], dics['k'], dics['l']
ind2h, ind2k, ind2l = dics__['h'], dics__['k'], dics__['l']

In [95]:
ind2h

{0: -13,
 1: -12,
 2: -11,
 3: -10,
 4: -9,
 5: -8,
 6: -7,
 7: -6,
 8: -5,
 9: -4,
 10: -3,
 11: -2,
 12: -1,
 13: 0,
 14: 1,
 15: 2,
 16: 3,
 17: 4,
 18: 5,
 19: 6,
 20: 7,
 21: 8,
 22: 9,
 23: 10,
 24: 11,
 25: 12}

In [96]:
num = 0
size = [1, 1]
for letter in 'hkl':
    size.append(hkl_minmax[letter][1] - hkl_minmax[letter][0] + 1)
low, high = np.zeros(size), np.zeros(size)
for j, ind in enumerate(ind_high):
    h, k, l = ind
    if h not in h2ind.keys() or k not in k2ind.keys() or l not in l2ind.keys():
        #print(ind)
        num += 1
    else: high[0, 0, h2ind[h], k2ind[k], l2ind[l]] = intensity[j]
for ind in ind_low:
    h, k, l = ind
    low[0, 0, h2ind[h], k2ind[k], l2ind[l]] = high[0, 0, h2ind[h], k2ind[k], l2ind[l]]
print(f'{num} / {len(ind_high)} didnt fit')

77 / 2208 didnt fit


In [97]:
low[0, 0, h2ind[-3], k2ind[0], l2ind[2]]

10478.054490274111

In [98]:
low[0, 0, h2ind[0], k2ind[2], l2ind[2]]

720.0053090711083

In [18]:
low, high = np.sqrt(low), np.sqrt(high)

In [19]:
low[0, 0, h2ind[0], k2ind[2], l2ind[2]]

26.832914658514238

In [20]:
factor = low.max()
high /= factor
low /= factor
low = torch.from_numpy(low).float()
high = torch.from_numpy(high).float()

In [21]:
low[0, 0, h2ind[0], k2ind[2], l2ind[2]]

tensor(0.1510)

In [22]:
from models.model import MiniUnet, UpBlock, DownBlock, DoubleConv
import torch
import torch.nn as nn
model = MiniUnet()
state_dict = {}
state_old = torch.load("Experiments/unet_3layers/14.09_clin_doob/model.ckpt")['state_dict']
for key in state_old.keys():
    key_new = key[6:]#.lstrip('model.')
    state_dict[key_new] = state_old[key]
model.load_state_dict(state_dict, strict=True)
print('Loaded successfully')

Loaded successfully


In [23]:
recon = nn.ReLU()(model(low)).detach()

In [24]:
recon[0, 0, h2ind[0], k2ind[2], l2ind[2]]

tensor(0.1402)

In [26]:
mask = low == 0
mask[0, 0, h2ind[0], k2ind[2], l2ind[2]]

tensor(False)

In [27]:
mask = low == 0
recon = recon * mask + low

In [28]:
recon[0, 0, h2ind[0], k2ind[2], l2ind[2]]

tensor(0.1510)

In [29]:
from torchmetrics import MeanSquaredError, MeanAbsoluteError
criterion = MeanSquaredError()

In [14]:
criterion(recon, high).detach()

tensor(0.0013)

In [63]:
recon *= factor
high *= factor
criterion(recon, high).detach()

tensor(41.2927)

In [64]:
recon = recon ** 2
high = high ** 2
criterion(recon, high).detach()

tensor(80412.4844)

In [67]:
print(recon[0, 0, 3, 3, 3], high[0, 0, 3, 3, 3])

tensor(81.8371, grad_fn=<SelectBackward0>) tensor(323.0356)


In [41]:
mask = low==0
recon = recon*mask + low

In [42]:
criterion(recon, high).detach()

tensor(0.0128)

In [32]:
recon *= factor
high *= factor
criterion(recon, high).detach()

tensor(39.9779)

In [33]:
recon[0, 0, h2ind[0], k2ind[2], l2ind[2]]

tensor(26.8329)

In [34]:
recon.max()

tensor(177.6708)

In [35]:
factor

177.67078344768026

In [36]:
recon = recon ** 2
high = high ** 2
criterion(recon, high).detach()

tensor(78132.0469)

In [37]:
recon[0, 0, h2ind[0], k2ind[2], l2ind[2]]

tensor(720.0053)

In [109]:
print(recon[0, 0, 3, 3, 3], high[0, 0, 3, 3, 3])

tensor(81.8371) tensor(323.0356)


In [30]:
import cctbx
from cctbx import miller
from cctbx import crystal
from cctbx.array_family import flex

group = 'P21/c'
ms = miller.build_set(
                crystal_symmetry=crystal.symmetry(
                space_group_symbol=group,
                unit_cell=(30,30,30,90,90,90)),
                anomalous_flag=False, d_min = 0.8
            )
ms_base = ms.customized_copy(
                space_group_info = ms.space_group().build_derived_point_group().info())
ms_all = ms_base.complete_set()
sys_abs = ms_all.lone_set(other=ms_base)
sys_abs_list = list(sys_abs.indices())
for ind in sys_abs_list:
    h, k, l = ind
    if h in h2ind.keys() and k in k2ind.keys() and l in l2ind.keys():
        recon[0, 0, h2ind[h], k2ind[k], l2ind[l]] = 0

In [31]:
recon[0, 0, h2ind[0], k2ind[2], l2ind[2]]

tensor(0.1510)

In [44]:
mask = high != 0
recon2 = recon * mask

In [45]:
criterion(recon2, high).detach()

tensor(0.0106)

In [46]:
recon2 *= factor
high *= factor
criterion(recon2, high).detach()

tensor(1.8809)

In [75]:
#recon = recon2 ** 2
#high = high ** 2
criterion(recon2, high).detach()

tensor(107.5768)

In [None]:
recon[0, 0, 0]

In [None]:
high[0, 0, 0]

In [38]:
from iotbx import shelx
from iotbx.shelx import hklf
from iotbx.shelx import write_ins
import iotbx

In [39]:
f = open("synt1_high.hkl", "w")
hklf.miller_array_export_as_shelx_hklf(a_high.as_intensity_array(), file_object = f)
f.close()

In [40]:
f = open("synt1_low.hkl", "w")
hklf.miller_array_export_as_shelx_hklf(a_low.as_intensity_array(), file_object = f)
f.close()

In [41]:
res = {'low': 1.5, 'high': 0.8, 'recon': 0.8}
for name in ['low', 'high', 'recon']:
    f = open(f'synt1_{name}.ins', 'w')
    shelx.write_ins.shelxd(f, 'data' ,struc.crystal_symmetry(), 26, 'C H N O Cl', res[name])
    f.close()

In [42]:
recon = recon.squeeze(0).squeeze(0)

In [43]:
recon.shape

torch.Size([26, 18, 23])

In [44]:
dics = {'h': [{}, {}], 'k': [{}, {}], 'l': [{}, {}]}
for letter in 'hkl':
    for i in range(hkl_minmax[letter][1] - hkl_minmax[letter][0] + 1):
        dics[letter][1][i] = hkl_minmax[letter][0] + i #ind2h
        dics[letter][0][hkl_minmax[letter][0] + i] = i #h2ind

In [45]:
h2ind, ind2h = dics['h']
k2ind, ind2k = dics['k']
l2ind, ind2l = dics['l']

In [46]:
recon_ind, recon_int = [], []
for h in range(recon.shape[0]):
    for k in range(recon.shape[1]):
        for l in range(recon.shape[2]):
            if round(float(recon[h, k, l]), 2) != 0:
                recon_ind.append((ind2h[h], ind2k[k], ind2l[l]))
                recon_int.append(round(float(recon[h, k, l]), 2))
                
        

In [48]:
recon[h2ind[0], k2ind[2], l2ind[2]]

tensor(720.0053)

In [155]:
float(recon[0, 0, 0])

0.04829457029700279

In [49]:
recon_int = list(map(lambda x: str(x) + '0' if str(x)[-2]=='.' else str(x), recon_int))

In [54]:
recon_int

['    0.05',
 '    0.29',
 '    0.31',
 '    0.11',
 '    0.03',
 '    0.10',
 '    0.01',
 '    0.01',
 '    0.02',
 '    0.01',
 '    0.42',
 '    3.96',
 '    3.33',
 '    2.67',
 '    2.43',
 '    1.52',
 '    0.90',
 '    1.17',
 '    1.16',
 '    0.23',
 '    0.15',
 '    0.13',
 '    0.01',
 '    0.02',
 '    0.02',
 '    0.01',
 '    0.01',
 '    0.01',
 '    0.01',
 '    0.01',
 '    0.01',
 '    0.78',
 '    1.08',
 '    0.74',
 '    0.48',
 '    0.26',
 '    0.12',
 '    0.15',
 '    0.23',
 '    0.05',
 '    0.56',
 '    0.18',
 '    0.03',
 '    0.01',
 '    0.01',
 '    0.02',
 '    0.04',
 '    0.02',
 '    0.02',
 '    0.01',
 '    0.03',
 '    0.02',
 '    0.06',
 '    0.24',
 '    0.26',
 '    0.11',
 '    0.04',
 '    0.01',
 '    0.01',
 '    0.18',
 '    0.47',
 '    0.22',
 '    0.07',
 '    0.04',
 '    0.01',
 '    0.02',
 '    0.03',
 '    0.03',
 '    0.03',
 '    0.03',
 '    0.04',
 '    0.02',
 '    0.01',
 '    0.04',
 '    0.02',
 '    0.02',
 '    0.01',

In [55]:
max(map(len, recon_int))

8

In [56]:
recon_int = list(map(lambda x: ' '*(8-len(x))+x, recon_int))

In [57]:
hh, kk, ll = [], [], []
for ind in recon_ind:
    h, k, l = ind
    hh.append(str(h))
    kk.append(str(k))
    ll.append(str(l))

In [58]:
max(map(len, hh))

3

In [59]:
max(map(len, kk))

2

In [60]:
max(map(len, ll))

2

In [65]:
recon[h2ind[-3], k2ind[0], l2ind[2]]

tensor(10478.0527)

In [68]:
(low[0, 0, h2ind[-3], k2ind[0], l2ind[2]] * factor ) ** 2

tensor(10478.0527)

In [61]:
hh = list(map(lambda x: ' '*(3-len(x))+x, hh))
kk = list(map(lambda x: ' '*(3-len(x))+x, kk))
ll = list(map(lambda x: ' '*(3-len(x))+x, ll))

In [62]:
strings = []
for i in range(len(hh)):
    strings.append(' ' + hh[i] + ' '+ kk[i] + ' ' + ll[i] + recon_int[i] + '    0.01')

In [63]:
strings

[' -13   0   0    0.05    0.01',
 ' -13   0   2    0.29    0.01',
 ' -13   0   4    0.31    0.01',
 ' -13   0   6    0.11    0.01',
 ' -13   0   8    0.03    0.01',
 ' -13   0  10    0.10    0.01',
 ' -13   0  16    0.01    0.01',
 ' -13   0  18    0.01    0.01',
 ' -13   0  20    0.02    0.01',
 ' -13   0  22    0.01    0.01',
 ' -13   1   0    0.42    0.01',
 ' -13   1   1    3.96    0.01',
 ' -13   1   2    3.33    0.01',
 ' -13   1   3    2.67    0.01',
 ' -13   1   4    2.43    0.01',
 ' -13   1   5    1.52    0.01',
 ' -13   1   6    0.90    0.01',
 ' -13   1   7    1.17    0.01',
 ' -13   1   8    1.16    0.01',
 ' -13   1   9    0.23    0.01',
 ' -13   1  10    0.15    0.01',
 ' -13   1  11    0.13    0.01',
 ' -13   1  13    0.01    0.01',
 ' -13   1  14    0.02    0.01',
 ' -13   1  16    0.02    0.01',
 ' -13   1  17    0.01    0.01',
 ' -13   1  18    0.01    0.01',
 ' -13   1  19    0.01    0.01',
 ' -13   1  20    0.01    0.01',
 ' -13   1  21    0.01    0.01',
 ' -13   1

In [64]:
with open('synt1_recon.hkl', 'w') as f:
    for line in strings:
        f.write(f"{line}\n")

In [184]:
from iotbx.file_reader import any_file

In [99]:
strings = []
with open('1.hkl', 'r') as f:
    x = f.readlines()

In [100]:
xx = [i.split() for i in x]

In [101]:
xx

[['4', '0', '0', '973.22', '52.38', 'o'],
 ['6', '0', '0', '1033.83', '41.44', 'o'],
 ['3', '1', '0', '207.22', '16.06', 'o'],
 ['4', '1', '0', '1068.90', '37.15', 'o'],
 ['5', '1', '0', '1573.81', '70.05', 'o'],
 ['6', '1', '0', '397.75', '20.41', 'o'],
 ['7', '1', '020243.40', '366.83', 'o'],
 ['2', '2', '0', '6837.36', '167.29', 'o'],
 ['3', '2', '0', '57.10', '10.16', 'o'],
 ['4', '2', '0', '1133.52', '30.18', 'o'],
 ['5', '2', '0', '2078.84', '49.56', 'o'],
 ['6', '2', '0', '7425.86', '146.19', 'o'],
 ['7', '2', '0', '139.07', '16.28', 'o'],
 ['1', '3', '0', '4772.97', '111.36', 'o'],
 ['2', '3', '0', '33.01', '20.65', 'o'],
 ['3', '3', '0', '332.92', '19.77', 'o'],
 ['4', '3', '0', '770.40', '35.25', 'o'],
 ['5', '3', '0', '24.05', '16.89', 'o'],
 ['6', '3', '0', '11.07', '16.32', 'o'],
 ['0', '4', '036237.064195.13', 'o'],
 ['1', '4', '0', '6711.28', '136.75', 'o'],
 ['2', '4', '0', '6079.01', '124.58', 'o'],
 ['3', '4', '0', '167.09', '19.69', 'o'],
 ['4', '4', '0', '-1.99', '1

In [104]:
hh, kk, ll = [], [], []
inds, ints = [], []
for line in xx:
    if len(line) == 6:
        inds.append((int(line[0]), int(line[1]), int(line[2])))
        ints.append(float(line[3]))
    elif len(line) == 5:
        tmp = line[2]
        l = tmp[0]
        intt = tmp[1:]
        print(tmp, l, intt)
        inds.append((int(line[0]), int(line[1]), int(l)))
        ints.append(float(intt))
    else:
        print(line, line[2])
        tmp = line[2]
        l = tmp[0]
        intt = tmp[1:]
        inttt = intt.split('.')
        intt = inttt[0] + '.' + inttt[1][:2]
        print(tmp, l, intt)
        inds.append((int(line[0]), int(line[1]), int(l)))
        ints.append(float(intt))
        

020243.40 0 20243.40
['0', '4', '036237.064195.13', 'o'] 036237.064195.13
036237.064195.13 0 36237.06
014541.93 0 14541.93
116657.71 1 16657.71
118518.94 1 18518.94
117146.36 1 17146.36
110477.03 1 10477.03
112562.24 1 12562.24
130358.30 1 30358.30
210584.23 2 10584.23
217268.73 2 17268.73
['0', '3', '228110.882610.01', 'o'] 228110.882610.01
228110.882610.01 2 28110.88
226531.92 2 26531.92
213439.04 2 13439.04
230892.83 2 30892.83
313954.80 3 13954.80
310041.48 3 10041.48
416712.60 4 16712.60
['0', '0', '488518.052522.10', 'o'] 488518.052522.10
488518.052522.10 4 88518.05
['0', '1', '459467.841128.94', 'o'] 459467.841128.94
459467.841128.94 4 59467.84
['0', '2', '469619.541679.74', 'o'] 469619.541679.74
469619.541679.74 4 69619.54
411211.27 4 11211.27
410425.84 4 10425.84
530802.94 5 30802.94
510155.72 5 10155.72
['0', '0', '640487.271103.13', 'o'] 640487.271103.13
640487.271103.13 6 40487.27
610630.83 6 10630.83


IndexError: list index out of range

In [107]:
ints

[973.22,
 1033.83,
 207.22,
 1068.9,
 1573.81,
 397.75,
 20243.4,
 6837.36,
 57.1,
 1133.52,
 2078.84,
 7425.86,
 139.07,
 4772.97,
 33.01,
 332.92,
 770.4,
 24.05,
 11.07,
 36237.06,
 6711.28,
 6079.01,
 167.09,
 -1.99,
 33.35,
 1837.37,
 14541.93,
 410.01,
 1400.63,
 3357.71,
 200.92,
 7205.42,
 1824.25,
 67.78,
 21.7,
 2464.4,
 2357.39,
 518.03,
 22.11,
 51.33,
 16657.71,
 1873.18,
 753.14,
 74.34,
 79.12,
 11.05,
 3772.36,
 3408.76,
 1229.39,
 2111.37,
 540.73,
 209.39,
 22.11,
 1366.78,
 18518.94,
 2020.09,
 78.59,
 610.37,
 17146.36,
 2509.84,
 1401.55,
 10477.03,
 1307.01,
 1418.64,
 76.57,
 81.36,
 26.15,
 42.3,
 3829.67,
 17.91,
 351.24,
 36.54,
 933.85,
 605.77,
 12562.24,
 584.38,
 2867.67,
 12.73,
 526.32,
 3625.28,
 4865.5,
 1386.47,
 22.87,
 6370.35,
 2741.96,
 77.84,
 30358.3,
 39.03,
 762.72,
 1911.23,
 789.78,
 1001.49,
 4578.53,
 1661.03,
 162.11,
 764.87,
 425.99,
 6913.98,
 274.36,
 2278.11,
 1233.02,
 498.63,
 5069.21,
 953.15,
 1823.07,
 299.81,
 2760.11,
 1090.54

In [110]:
num = 0
size = [1, 1]
for letter in 'hkl':
    size.append(hkl_minmax[letter][1] - hkl_minmax[letter][0] + 1)
low = np.zeros(size)
for j, ind in enumerate(inds):
    h, k, l = ind
    if h not in h2ind.keys() or k not in k2ind.keys() or l not in l2ind.keys():
        #print(ind)
        num += 1
    else: low[0, 0, h2ind[h], k2ind[k], l2ind[l]] = ints[j]
print(f'{num} / {len(inds)} didnt fit')
low = np.sqrt(np.abs(low))
factor = low.max()
low /= factor
low = torch.from_numpy(low).float()

0 / 466 didnt fit


In [111]:
(low[0, 0, h2ind[0], k2ind[2], l2ind[8]] * factor) ** 2

tensor(849.8600)

In [112]:
low.unique()

tensor([0.0000, 0.0023, 0.0034, 0.0047, 0.0052, 0.0061, 0.0065, 0.0085, 0.0087,
        0.0094, 0.0095, 0.0097, 0.0101, 0.0101, 0.0107, 0.0108, 0.0110, 0.0112,
        0.0112, 0.0118, 0.0120, 0.0120, 0.0120, 0.0126, 0.0131, 0.0133, 0.0137,
        0.0139, 0.0140, 0.0142, 0.0145, 0.0148, 0.0151, 0.0153, 0.0156, 0.0157,
        0.0158, 0.0158, 0.0158, 0.0158, 0.0161, 0.0163, 0.0163, 0.0165, 0.0169,
        0.0170, 0.0171, 0.0172, 0.0172, 0.0175, 0.0182, 0.0183, 0.0184, 0.0187,
        0.0188, 0.0189, 0.0193, 0.0194, 0.0202, 0.0203, 0.0210, 0.0211, 0.0213,
        0.0217, 0.0219, 0.0220, 0.0228, 0.0233, 0.0237, 0.0241, 0.0246, 0.0246,
        0.0254, 0.0255, 0.0264, 0.0265, 0.0268, 0.0269, 0.0270, 0.0274, 0.0277,
        0.0287, 0.0290, 0.0291, 0.0294, 0.0295, 0.0296, 0.0297, 0.0298, 0.0299,
        0.0300, 0.0301, 0.0303, 0.0303, 0.0320, 0.0329, 0.0331, 0.0342, 0.0343,
        0.0343, 0.0350, 0.0358, 0.0372, 0.0374, 0.0374, 0.0389, 0.0390, 0.0391,
        0.0392, 0.0396, 0.0405, 0.0415, 

In [113]:
recon = nn.ReLU()(model(low)).detach()
mask = low == 0
recon = recon * mask + low

In [114]:
(recon[0, 0, h2ind[0], k2ind[2], l2ind[8]] * factor) ** 2

tensor(849.8600)

In [246]:
import cctbx
from cctbx import miller
from cctbx import crystal
from cctbx.array_family import flex

group = 'P21/n'
ms = miller.build_set(
                crystal_symmetry=crystal.symmetry(
                space_group_symbol=group,
                unit_cell=(30,30,30,90,90,90)),
                anomalous_flag=False, d_min = 0.8
            )
ms_base = ms.customized_copy(
                space_group_info = ms.space_group().build_derived_point_group().info())
ms_all = ms_base.complete_set()
sys_abs = ms_all.lone_set(other=ms_base)
sys_abs_list = list(sys_abs.indices())
for ind in sys_abs_list:
    h, k, l = ind
    if h in h2ind.keys() and k in k2ind.keys() and l in l2ind.keys():
        recon[0, 0, h2ind[h], k2ind[k], l2ind[l]] = 0

In [115]:
recon *= factor
recon = recon ** 2

In [116]:
recon[0, 0, h2ind[0], k2ind[2], l2ind[8]]

tensor(849.8600)

In [117]:
recon.shape

torch.Size([1, 1, 26, 18, 23])

In [118]:
recon = recon.squeeze(0).squeeze(0)

In [119]:
recon_ind, recon_int = [], []
for h in range(recon.shape[0]):
    for k in range(recon.shape[1]):
        for l in range(recon.shape[2]):
            if round(float(recon[h, k, l]), 2) != 0:
                recon_ind.append((ind2h[h], ind2k[k], ind2l[l]))
                recon_int.append(round(float(recon[h, k, l]), 2))

In [121]:
for i in range(len(recon_ind)):
    if recon_ind[i] == (0, 2, 8):
        print(i, recon_int[i])

5291 849.86


In [122]:
recon_int = list(map(lambda x: str(x) + '0' if str(x)[-2]=='.' else str(x), recon_int))

In [123]:
recon_int[5291]

'849.86'

In [265]:
max(map(len, recon_int))

7

In [124]:
recon_int = list(map(lambda x: ' '*(7-len(x))+x, recon_int))

In [126]:
recon_int[5291]

' 849.86'

In [127]:
hh, kk, ll = [], [], []
for ind in recon_ind:
    h, k, l = ind
    hh.append(str(h))
    kk.append(str(k))
    ll.append(str(l))

In [128]:
hh = list(map(lambda x: ' '*(3-len(x))+x, hh))
kk = list(map(lambda x: ' '*(3-len(x))+x, kk))
ll = list(map(lambda x: ' '*(3-len(x))+x, ll))

In [129]:
strings = []
for i in range(len(hh)):
    strings.append(' ' + hh[i] + ' '+ kk[i] + ' ' + ll[i] + recon_int[i] + '    0.01')

In [130]:
with open('1_recon.hkl', 'w') as f:
    for line in strings:
        f.write(f"{line}\n")