In [33]:
import os
import numpy as np
import torch
from torch import nn
import gc
import time
import lightning as L
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor 
from lightning.pytorch.loggers import WandbLogger
from datetime import datetime
import pytz
import wandb

# Set to 'medium' or 'high'
torch.set_float32_matmul_precision('high')

class ResidualConvBlock(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(ResidualConvBlock, self).__init__()

        self.conv1 = nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1, stride=1, bias=True)
        self.norm1 = nn.InstanceNorm3d(mid_channels)
        self.prelu1 = nn.PReLU()
        
        self.conv2 = nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1, stride=1,  bias=True)
        self.norm2 = nn.InstanceNorm3d(out_channels)
        self.prelu2 = nn.PReLU()
        
        self.downsample = nn.Sequential()
        if in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv3d(in_channels, out_channels,  kernel_size=3, padding=1, stride=1,  bias=True),
                nn.InstanceNorm3d(out_channels)
            )
            
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.prelu1(out)
        out = self.conv2(out)
        out = self.norm2(out)
        residual = self.downsample(x) # if downsample true, otherwise residual = x
        out += residual
        out = self.prelu2(out)

        return out


class GeneratorNestedUNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(GeneratorNestedUNet, self).__init__()
        
        # Nested UNet (UNet++), 4 levels
        num_channels = [32, 64, 128, 256, 512]

        # Depth 1
        self.conv0_0 = ResidualConvBlock(in_channels, num_channels[0], num_channels[0]) # input
        
        # Depth 2
        self.conv1_0 = ResidualConvBlock(num_channels[0], num_channels[1], num_channels[1]) # down 1
        self.conv0_1 = ResidualConvBlock(num_channels[0]+num_channels[1], num_channels[0], num_channels[0]) # up 1
        
        # Depth 3
        self.conv2_0 = ResidualConvBlock(num_channels[1], num_channels[2], num_channels[2]) # down 2
        self.conv1_1 = ResidualConvBlock(num_channels[1]+num_channels[2], num_channels[1], num_channels[1]) # up 1
        self.conv0_2 = ResidualConvBlock(num_channels[0]*2+num_channels[1], num_channels[0], num_channels[0]) # up 2

        # Depth 4
        self.conv3_0 = ResidualConvBlock(num_channels[2], num_channels[3], num_channels[3]) # down 3
        self.conv2_1 = ResidualConvBlock(num_channels[2]+num_channels[3], num_channels[2], num_channels[2]) # up 1
        self.conv1_2 = ResidualConvBlock(num_channels[1]*2+num_channels[2], num_channels[1], num_channels[1]) # up 2
        self.conv0_3 = ResidualConvBlock(num_channels[0]*3+num_channels[1], num_channels[0], num_channels[0]) # up 3
        
        self.final = nn.Conv3d(num_channels[0], out_channels, kernel_size=1, stride=1, padding=0, bias=True) # output
        
    def forward(self, x):
        up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        pool = nn.MaxPool3d(kernel_size=2, stride=2)
        # Downsampling
        x0_0 = self.conv0_0(x) 
        x1_0 = self.conv1_0(pool(x0_0)) # down 1
        x2_0 = self.conv2_0(pool(x1_0)) # down 2
        x3_0 = self.conv3_0(pool(x2_0)) # down 3
        # Upsampling
        x0_1 = self.conv0_1(torch.cat([x0_0, up(x1_0)], 1)) # up 1
        x1_1 = self.conv1_1(torch.cat([x1_0, up(x2_0)], 1)) # up 1
        x2_1 = self.conv2_1(torch.cat([x2_0, up(x3_0)], 1)) # up 1
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, up(x1_1)], 1)) # up 2
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, up(x2_1)], 1)) # up 2
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, up(x1_2)], 1)) # up 3
        out = self.final(x0_3)
        
        return out


class Discriminator(nn.Module):
    def __init__(self, in_channels=1, num_classes=1):
        super(Discriminator, self).__init__()
        
        num_channels = [32, 64, 128, 256, 512]
        
        self.prelu = nn.PReLU()
        self.relu = nn.ReLU()
        
        self.conv1 = nn.Conv3d(in_channels, num_channels[0], kernel_size=3, padding=1, stride=1, bias=True)
        self.norm1 = nn.InstanceNorm3d(num_channels[0])
        
        self.conv2 = nn.Conv3d(num_channels[0], num_channels[1], kernel_size=3, padding=1, stride=1, bias=True)
        self.norm2 = nn.InstanceNorm3d(num_channels[1])

        self.conv3 = nn.Conv3d(num_channels[1], num_channels[2], kernel_size=3, padding=1, stride=1, bias=True)
        self.norm3 = nn.InstanceNorm3d(num_channels[2])
        
        self.conv4 = nn.Conv3d(num_channels[2], num_channels[3], kernel_size=3, padding=1, stride=1, bias=True)
        # self.norm4 = nn.InstanceNorm3d(num_channels[3])
        # self.conv5 = nn.Conv3d(num_channels[3], num_channels[4], kernel_size=1, padding=0, stride=1, bias=True)

        # Adaptive pooling and linear layers for classification
        self.pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(num_channels[3], 100)
        self.fc2 = nn.Linear(100, 25)
        self.fc3 = nn.Linear(25, num_classes)
        
    def forward(self, x):
        x = self.prelu(self.norm1(self.conv1(x)))
        x = self.prelu(self.norm2(self.conv2(x)))
        x = self.prelu(self.norm3(self.conv3(x)))
        # x = self.prelu(self.norm4(self.conv4(x)))
        x = self.conv4(x)
        # pooling and pass to fc layers
        x = self.pool(x)
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        
        return x
    

In [34]:
class GAN(L.LightningModule):
    def __init__(
        self,
        data_shape: tuple = (1, 32, 32, 32), # TODO
        lrG: float = 1,
        lrD: float = 1,
        b1: float = 0.5,
        b2: float = 0.999,
        alpha: float = 0.001,
        **kwargs,
        ):
        super(GAN, self).__init__()
        self.train_loss_each_epoch = []
        self.val_loss_each_epoch = []
        self.save_hyperparameters()
        self.automatic_optimization = False
                
        # Generator
        self.generator = GeneratorNestedUNet(in_channels=data_shape[0], out_channels=data_shape[0])
        self.discriminator = Discriminator(in_channels=data_shape[0], num_classes=1)
        # normal initialization        
        self.generator.apply(self.weights_init_normal)
        self.discriminator.apply(self.weights_init_normal)

    def forward(self, x):
        return self.generator(x)

    @staticmethod
    def adversarial_loss(y_pred, y_target):
        return nn.BCEWithLogitsLoss()(y_pred, y_target)
    
    @staticmethod
    def voxelwise_loss(y_pred, y_target):
        return nn.SmoothL1Loss()(y_pred, y_target)    
    
    @staticmethod
    def weights_init_normal(m):
        if isinstance(m, nn.Conv3d) or isinstance(m, nn.InstanceNorm3d): # we don't set weights for InstanceNorm3d layer (affine=False)
            if hasattr(m, 'weight') and m.weight is not None:
                torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.01)
            if hasattr(m, 'bias') and m.bias is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)
           
    def training_step(self, batch):
        x, y = batch
        batch_size = x.shape[0]
        
        optimizer_G, optimizer_D = self.optimizers()
        
        # make ground truth labels
        real_label = torch.ones((batch_size, 1), dtype=torch.float, device=self.device, requires_grad=False)
        fake_label = torch.zeros((batch_size, 1), dtype=torch.float, device=self.device, requires_grad=False)
        
        ###################
        # train Generator #
        ###################
        self.toggle_optimizer(optimizer_G)
        
        gen_imgs = self.generator(x)
        
        voxel_loss = self.voxelwise_loss(gen_imgs, y)
        genadv_loss = self.adversarial_loss(self.discriminator(gen_imgs), real_label)
        loss_G = voxel_loss + self.hparams.alpha * genadv_loss 
        
        optimizer_G.zero_grad()
        self.manual_backward(loss_G)
        optimizer_G.step()
        self.untoggle_optimizer(optimizer_G)
    
        #######################
        # train Discriminator #
        #######################
        self.toggle_optimizer(optimizer_D)
        
        real_loss = self.adversarial_loss(self.discriminator(y), real_label)
        fake_loss = self.adversarial_loss(self.discriminator(gen_imgs.detach()), fake_label)    
        loss_D = (real_loss + fake_loss) / 2
        
        optimizer_D.zero_grad()
        self.manual_backward(loss_D)
        optimizer_D.step()
        self.untoggle_optimizer(optimizer_D)
        
        self.log_dict(
            {
            'loss_G': loss_G,
            'loss_D': loss_D,
            'loss_voxelL1': voxel_loss,
            'loss_genadvBCE': genadv_loss,
            'loss_realD': real_loss,
            'loss_fakeD': fake_loss,
            },
            prog_bar=False,
            on_step=False,
            on_epoch=True,
        )
    
    # validation step
    def validation_step(self, batch):
        x, y = batch
        batch_size = x.shape[0]
        
        gen_imgs = self.generator(x)
        
        # make ground truth labels
        real_label = torch.ones((batch_size, 1), dtype=torch.float, device=self.device, requires_grad=False)
        fake_label = torch.zeros((batch_size, 1), dtype=torch.float, device=self.device, requires_grad=False)
        
        voxel_loss = self.voxelwise_loss(gen_imgs, y)
        genadv_loss = self.adversarial_loss(self.discriminator(gen_imgs), real_label)
        loss_G = voxel_loss + self.hparams.alpha * genadv_loss 
        
        real_loss = self.adversarial_loss(self.discriminator(y), real_label)
        fake_loss = self.adversarial_loss(self.discriminator(gen_imgs.detach()), fake_label)    
        loss_D = (real_loss + fake_loss) / 2
        
        self.log_dict(
            {
            # 'epoch': self.current_epoch,
            'val_loss_G': loss_G,
            'val_loss_D': loss_D,
            'val_loss_voxelL1': voxel_loss,
            'val_loss_genadvBCE': genadv_loss,
            'val_loss_realD': real_loss,
            'val_loss_fakeD': fake_loss,
            'val_loss': loss_G + loss_D, # only for EarlyStopping
            },
            prog_bar=False,
            on_step=False,
            on_epoch=True,
        )   

    def configure_optimizers(self):
        optimizer_G = torch.optim.NAdam(self.generator.parameters(), lr=self.hparams.lrG)
        optimizer_D = torch.optim.NAdam(self.discriminator.parameters(), lr=self.hparams.lrD)
        
        scheduler_G = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_G, mode='min', factor=0.1, 
                                                               patience=1, eps=1e-10,)
        scheduler_D = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_D, mode='min', factor=0.1, 
                                                               patience=5, eps=1e-10)
        
        return [
            {"optimizer": optimizer_G, "lr_scheduler": scheduler_G, "monitor": "val_loss_G"},
            {"optimizer": optimizer_D, "lr_scheduler": scheduler_D, "monitor": "val_loss_D"},
        ]

        # scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=2, gamma=0.1)
        # scheduler_D = torch.optim.lr_scheduler.StepLR(optimizer_D, step_size=5, gamma=0.1)
        # return [
        #     {"optimizer": optimizer_G, "lr_scheduler": scheduler_G, "interval": "epoch", "frequency": 1},
        #     {"optimizer": optimizer_D, "lr_scheduler": scheduler_D, "interval": "epoch", "frequency": 1},
        # ]
        
    
    def on_train_epoch_end(self):
        # clear memory
        torch.cuda.empty_cache() # clear gpu
        gc.collect() # clear cpu
    
    def on_validation_epoch_end(self):
        # Step the schedulers
        val_loss_G = self.trainer.logged_metrics['val_loss_G']
        val_loss_D = self.trainer.logged_metrics['val_loss_D']
        for lr_scheduler in self.lr_schedulers():
            lr_scheduler.step(val_loss_G)
            lr_scheduler.step(val_loss_D)
         
        # clear memory
        torch.cuda.empty_cache()
        gc.collect()
        

In [35]:
# File naming
def runName():
    # Get UTC time and convert to Pacific Time
    now_utc = datetime.now(pytz.timezone('UTC'))
    now_pst = now_utc.astimezone(pytz.timezone('US/Pacific'))
    now_str = now_pst.strftime('%y-%m%d-%H%M%S')
    # create the log savepath
    fpath = f'./checkpoints/gan_model/{now_str}'
    os.makedirs(fpath, exist_ok=True)
    
    return fpath, now_str      


# Callbacks
class EpochPrintCallback(L.Callback):
    def on_train_start(self, trainer, pl_module, column_width=12):
        self.column_width = column_width
        title = "|"
        title += "epoch".center(self.column_width) + "|"
        title += "time".center(self.column_width) + "|"
        title += "loss_G".center(self.column_width) + "|"
        title += "loss_D".center(self.column_width) + "|"
        title += "val_loss_G".center(self.column_width) + "|"
        title += "val_loss_D".center(self.column_width) + "|"
        self.row = "-" * len(title)
        print(self.row)
        print(title)
        print(self.row)
    
    def on_train_epoch_start(self, trainer, pl_module):
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module):
        epoch_time = time.time() - self.start_time
        output = "\r|"
        output += f"{trainer.current_epoch}".center(self.column_width) + "|"
        output += f"{epoch_time:.2f}".center(self.column_width) + "|"
        output += f"{trainer.logged_metrics['loss_G']:.4f}".center(self.column_width) + "|"
        output += f"{trainer.logged_metrics['loss_D']:.4f}".center(self.column_width) + "|"
        output += f"{trainer.logged_metrics['val_loss_G']:.4f}".center(self.column_width) + "|"
        output += f"{trainer.logged_metrics['val_loss_D']:.4f}".center(self.column_width) + "|"
        print(output)
        print(self.row) 
        
    def on_train_end(self, trainer, pl_module):
        done = "| Training completed on epoch " + str(trainer.current_epoch-1) + " / " + str(trainer.max_epochs-1) + " |"
        row = "-" * len(done)
        print("\n")
        print(row)
        print(done)
        print(row)

In [36]:
## FAKE DATA

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split

class MyDataset(Dataset):
    def __init__(self, num_samples, img_size):
        self.num_samples = num_samples
        self.img_size = img_size
        
    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        if idx >= self.__len__():
            raise IndexError('Index out of range')
        # Generate a fake 3D image and label
        np.random.seed(0)
        img = np.random.rand(*self.img_size)
        np.random.seed(42)
        label = np.random.rand(*self.img_size)

        # Convert to PyTorch tensors
        img = torch.from_numpy(img).float().unsqueeze(0)  
        label = torch.from_numpy(label).float().unsqueeze(0)  

        return img, label

# Define the size of the images (depth, height, width)
img_size = (32, 32, 32)

# Instantiate dataset
dataset = MyDataset(num_samples=100, img_size=img_size)
train_dataset, val_dataset = random_split(dataset, [80, 20])

# Instantiate dataloader
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=8)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=8)

In [None]:
class LearningRateMonitorCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        for i, lr_scheduler_config in enumerate(trainer.lr_scheduler_configs):
            scheduler = lr_scheduler_config.scheduler
            print(scheduler.optimizer.param_groups[0])
            current_lr = scheduler.optimizer.param_groups[0]['lr']
            if i == 0:
                self.log('lr_G', current_lr, sync_dist=True)
            elif i == 1:
                self.log('lr_D', current_lr, sync_dist=True)


fpath, now_str = runName()

logger = WandbLogger(
    project='GAN',
    save_dir=fpath, 
    name=now_str, 
    )
early_stop = EarlyStopping(
    monitor='val_loss', 
    patience=20, 
    verbose=False
    )
lr_monitor = LearningRateMonitor(
    logging_interval='epoch'
    )
checkpoints = ModelCheckpoint(
    dirpath=fpath,
    filename='{epoch}-{val_loss_G:.4f}-{val_loss_D:.4f}',
    monitor='val_loss',
    save_top_k=1,
    save_last=True,
    )

model = GAN()
trainer = L.Trainer(
    accelerator="gpu", 
    devices=1, 
    max_epochs=20, 
    log_every_n_steps=1, 
    logger=logger,
    enable_progress_bar=False,
    callbacks=[
        EpochPrintCallback(),
        # lr_monitor,
        # early_stop,
        checkpoints,
        LearningRateMonitorCallback()
        ],
    )
trainer.fit(model, train_dataloader, val_dataloader)

# Finish any existing WandB run
wandb.finish()

In [38]:
trainer.logged_metrics

{'val_loss_G': tensor(15.1391),
 'val_loss_D': tensor(0.6981),
 'val_loss_voxelL1': tensor(15.1385),
 'val_loss_genadvBCE': tensor(0.5985),
 'val_loss_realD': tensor(0.5985),
 'val_loss_fakeD': tensor(0.7977),
 'val_loss': tensor(15.8372),
 'loss_G': tensor(15.1391),
 'loss_D': tensor(0.6981),
 'loss_voxelL1': tensor(15.1385),
 'loss_genadvBCE': tensor(0.5985),
 'loss_realD': tensor(0.5985),
 'loss_fakeD': tensor(0.7977),
 'lr_G': tensor(1.0000e-10),
 'lr_D': tensor(1.0000e-06)}

In [39]:
trainer.optimizers

[NAdam (
 Parameter Group 0
     betas: (0.9, 0.999)
     capturable: False
     decoupled_weight_decay: False
     differentiable: False
     eps: 1e-08
     foreach: None
     lr: 1.0000000000000006e-10
     momentum_decay: 0.004
     weight_decay: 0
 ),
 NAdam (
 Parameter Group 0
     betas: (0.9, 0.999)
     capturable: False
     decoupled_weight_decay: False
     differentiable: False
     eps: 1e-08
     foreach: None
     lr: 1.0000000000000004e-06
     momentum_decay: 0.004
     weight_decay: 0
 )]

In [43]:
scheduler  = trainer.lr_scheduler_configs[0].scheduler
current_lr = scheduler.optimizer.param_groups[0]['lr']
current_lr

1.0000000000000006e-10

In [47]:
scheduler.optimizer.param_groups[0]

{'params': [Parameter containing:
  tensor([[[[[-0.4123,  0.8182,  1.1537],
             [-0.6044,  1.1536,  1.1491],
             [ 2.3473, -0.0456,  1.6257]],
  
            [[-0.5899,  1.7960, -1.3317],
             [-2.9784,  0.6111,  0.9009],
             [-2.0062,  0.0508, -3.3020]],
  
            [[-0.8003,  0.7816, -2.7513],
             [-3.1621, -0.9283, -0.8631],
             [-1.3250,  0.9783,  2.7081]]]],
  
  
  
          [[[[-3.1674,  1.1726, -3.2069],
             [-2.8701,  0.2328, -1.1006],
             [-2.9656,  1.0066, -1.2747]],
  
            [[-1.5737, -0.5812, -1.2039],
             [-1.0155,  1.5639, -0.5743],
             [-3.3872,  1.6449, -1.4261]],
  
            [[ 0.8120,  0.4860, -3.1878],
             [ 0.9481, -0.5352, -0.9344],
             [ 0.3161,  2.5466, -0.9952]]]],
  
  
  
          [[[[ 3.1073,  3.3089,  1.1765],
             [-1.3057, -2.1862, -1.1706],
             [-0.0178,  0.0347, -1.1730]],
  
            [[ 3.1983,  3.2183,  1.1543]

In [41]:
scheduler.optimizer.param_groups[0]

{'params': [Parameter containing:
  tensor([[[[[-0.4123,  0.8182,  1.1537],
             [-0.6044,  1.1536,  1.1491],
             [ 2.3473, -0.0456,  1.6257]],
  
            [[-0.5899,  1.7960, -1.3317],
             [-2.9784,  0.6111,  0.9009],
             [-2.0062,  0.0508, -3.3020]],
  
            [[-0.8003,  0.7816, -2.7513],
             [-3.1621, -0.9283, -0.8631],
             [-1.3250,  0.9783,  2.7081]]]],
  
  
  
          [[[[-3.1674,  1.1726, -3.2069],
             [-2.8701,  0.2328, -1.1006],
             [-2.9656,  1.0066, -1.2747]],
  
            [[-1.5737, -0.5812, -1.2039],
             [-1.0155,  1.5639, -0.5743],
             [-3.3872,  1.6449, -1.4261]],
  
            [[ 0.8120,  0.4860, -3.1878],
             [ 0.9481, -0.5352, -0.9344],
             [ 0.3161,  2.5466, -0.9952]]]],
  
  
  
          [[[[ 3.1073,  3.3089,  1.1765],
             [-1.3057, -2.1862, -1.1706],
             [-0.0178,  0.0347, -1.1730]],
  
            [[ 3.1983,  3.2183,  1.1543]

In [38]:
lr_monitor.lrs

{'lr-NAdam': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'lr-NAdam-1': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [20]:
trainer.lr_scheduler_configs[0].scheduler

<torch.optim.lr_scheduler.ReduceLROnPlateau at 0x7faae929fd00>

In [1]:
import os
import sys
import subprocess
import numpy as np
import mrcfile
from copy import deepcopy

In [42]:
class PreprocessMap():
    def __init__(self, mappath, save_dir, emd_id, CHIMERAX_PATH='/usr/bin/chimerax', verbose=False, rm_interim=True):
        self.chimerax_path = CHIMERAX_PATH
        self.path = mappath
        self.save_dir = save_dir
        self.verbose = verbose
        self.rm_interim = rm_interim
        self.ogmap_path = os.path.join(self.path, f'emd_{emd_id}.map')
        self.resample_path = os.path.join(self.save_dir, f'{emd_id}_resample.mrc')
        self.norm_path = os.path.join(self.save_dir, f'{emd_id}_norm.mrc')
        self.sim_path = os.path.join(self.save_dir, f'{emd_id}_sim.mrc')
        self.sim_norm_path = os.path.join(self.save_dir, f'{emd_id}_sim_norm.mrc')
    
    
    @staticmethod
    def normalization(cleanmap, output_path, percent=99):
        # normalize by percentile value
        mrc_data = deepcopy(cleanmap.data)   
        percentile = np.percentile(mrc_data[np.nonzero(mrc_data)], percent)
        mrc_data /= percentile

        # set value < 0 to 0; value > 1 to 1
        mrc_data[mrc_data < 0] = 0
        mrc_data[mrc_data > 1] = 1
        
        # write to new mrc file
        with mrcfile.new(f'{output_path}', overwrite=True) as mrc:
            mrc.set_data(mrc_data)
            mrc.voxel_size = 1
            mrc.header.origin = cleanmap.header.origin
            mrc.close()

        
    # Resample original maps to 1.0 A/voxel
    def resampleMap(self, emd_id):
        # Execute ChimeraX resampling
        result = subprocess.run([self.chimerax_path, '--nogui', 
                                '--cmd', 
                                f'open {self.ogmap_path}; \
                                vol #1 step 1 ; \
                                vol resample #1 spacing 1.0; \
                                vol #2 step 1 ; \
                                save {self.resample_path} #2; \
                                exit'],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                text=True)
        if self.verbose:
            print(result.stdout)
    
    
    # Make normalized original map
    def normalizeMap(self, emd_id):
        # Normalize
        cleanmap = mrcfile.open(f'{self.resample_path}', mode='r')
        self.normalization(cleanmap, self.norm_path)
            
        if self.verbose:
            print(f'Normalized map saved as {self.norm_path}')
        
        
    # Make simulation map
    def makeSimMap(self, emd_id, pdb_id):
        # Execute ChimeraX molmap with resolution=2.0
        result = subprocess.run([self.chimerax_path, '--nogui', 
                                '--cmd', 
                                f'open {self.norm_path}; \
                                open {self.path}/{pdb_id}_ref.pdb; \
                                vol #1 step 1 ; \
                                molmap #2 2.0 onGrid #1; \
                                vol #3 step 1 ; \
                                save {self.sim_path} #3; \
                                exit'],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                text=True)
        if self.verbose:
            print(result.stdout)
            
        # Make normalized map
        cleanmap = mrcfile.open(f'{self.sim_path}', mode='r')
        self.normalization(cleanmap, self.sim_norm_path)
        
        if self.verbose:
            print(f'Normalized Simulated map saved as {self.sim_norm_path}')
        
        
    # Process the map
    def process(self, emd_id, pdb_id):
        self.resampleMap(emd_id)
        self.normalizeMap(emd_id)
        self.makeSimMap(emd_id, pdb_id)
        
        # delete interim maps
        if self.rm_interim:
            os.remove(f'{self.resample_path}')
            os.remove(f'{self.sim_path}')
        
        if self.verbose:
            print(f'Resampled map removed: {self.resample_path}')
            print(f'Simulated map removed: {self.sim_path}')
            
        print(f'Processing of {emd_id} completed.')
        

In [None]:
emd_id = '8685'
pdb_id = '5vhw'
preprocess = PreprocessMap(mappath='./data/raw_gan_data', save_dir='./', verbose=False)

In [35]:
preprocess.process(emd_id=emd_id, pdb_id=pdb_id)

Processing of 8685 completed.


In [None]:
import mrcfile
import numpy as np
from scipy.ndimage import center_of_mass

emdid = 26801 
map1 = f'./data/processed_gan_data/{emdid}_norm.mrc'
map2 = f'./data/processed_gan_data/{emdid}_sim_norm.mrc'

# Load the full chain map and the one-chain map
with mrcfile.open(map1) as mrc_1, mrcfile.open(map2) as mrc_2:
    data_full = mrc_1.data
    data_one = mrc_2.data

# Normalize both maps to a common scale
data_full = data_full / np.max(data_full)
data_one = data_one / np.max(data_one)

# Compute centers of mass as a quick similarity check
center_full = center_of_mass(data_full)
center_one = center_of_mass(data_one)

# Calculate a threshold (adjust as needed)
similarity_threshold = 0.01

# Compare centers of mass to detect similarity/difference
distance = np.linalg.norm(np.array(center_full) - np.array(center_one))

# Consider maps similar if the centers are within a small distance
if distance < similarity_threshold:
    print("The two maps are similar (all chains).")
else:
    print("The two maps are different (single chain vs. multiple chains).")


### Cubes

In [2]:
from map2cube import *
import mrcfile

In [35]:
emdid = '0590' 
map1 = f'./data/{emdid}_norm.mrc'
mdata = mrcfile.open(map1, mode='r').data
mdata.shape, map1

((207, 207, 207), './data/0590_norm.mrc')

In [36]:
mcube = np.array(create_cube(mdata, box_size=32, core_size=20))

In [37]:
mcube.shape, mcube[0].shape, len(mcube)

((1331, 32, 32, 32), (32, 32, 32), 1331)

In [14]:
remap = reconstruct_map(mcube, image_shape=mdata.shape, box_size=32, core_size=20)
remap.shape

(207, 207, 207)

In [13]:
np.all(remap == mdata)

True

In [None]:
write2map(map1, remap, save_dir='.', verbose=True)

In [None]:
write2map(map1, remap, save_path='./')

### MRC read

In [61]:
import mrcfile
import numpy as np

map1 = './output/nogrid.mrc'
map2 = './output/grid.mrc'
map3 = './output/emd_0346.map'


dmap1 = mrcfile.open(map1, mode='r+')
dmap2 = mrcfile.open(map2, mode='r')
dmap3 = mrcfile.open(map3, mode='r')


In [52]:
dmap1.header.cella, dmap2.header.cella, dmap3.header.cella

(rec.array((117.333336, 80., 182.66667),
           dtype=[('x', '<f4'), ('y', '<f4'), ('z', '<f4')]),
 rec.array((271.36, 271.36, 271.36),
           dtype=[('x', '<f4'), ('y', '<f4'), ('z', '<f4')]),
 rec.array((271.36, 271.36, 271.36),
           dtype=[('x', '<f4'), ('y', '<f4'), ('z', '<f4')]))

In [53]:
dmap1.header.origin, dmap2.header.origin, dmap3.header.origin

(rec.array((77.202, 95.771, 52.989),
           dtype=[('x', '<f4'), ('y', '<f4'), ('z', '<f4')]),
 rec.array((0., 0., 0.),
           dtype=[('x', '<f4'), ('y', '<f4'), ('z', '<f4')]),
 rec.array((0., 0., 0.),
           dtype=[('x', '<f4'), ('y', '<f4'), ('z', '<f4')]))

In [54]:
dmap1.data.shape, dmap2.data.shape, dmap3.data.shape

((274, 120, 176), (256, 256, 256), (256, 256, 256))

In [55]:
dmap1.voxel_size, dmap2.voxel_size, dmap3.voxel_size

(rec.array((0.6666667, 0.6666667, 0.6666667),
           dtype=[('x', '<f4'), ('y', '<f4'), ('z', '<f4')]),
 rec.array((1.06, 1.06, 1.06),
           dtype=[('x', '<f4'), ('y', '<f4'), ('z', '<f4')]),
 rec.array((1.06, 1.06, 1.06),
           dtype=[('x', '<f4'), ('y', '<f4'), ('z', '<f4')]))

In [58]:
from copy import deepcopy

In [59]:
deepcopy(dmap1.data)

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0.

In [70]:
import subprocess
chimerax = '/usr/bin/chimerax'
ref_map = './output/emd_0346.map'
sim_map = './app/6n52_ref_norm_molmap.mrc'
gan_map = './app/6n52_ref_norm_molmap_gan.mrc'



def resanple(chimerax, ref_map, sim_map, gan_map, verbose=False):
    result = subprocess.run([chimerax, '--nogui', 
                            '--cmd', 
                            f'open {ref_map}; \
                            open {sim_map}; \
                            open {gan_map}; \
                            vol #1 #2 #3 step 1 ; \
                            vol resample #2 onGrid #1 gridStep 1; \
                            vol resample #3 onGrid #1 gridStep 1; \
                            save /u01/44.mrc #4; \
                            save /u01/55.mrc #5; \
                            exit'],
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE,
                            text=True)

In [71]:
resanple(chimerax, ref_map, sim_map, gan_map, verbose=True)