In [None]:
import atom3d.datasets as da
#da.download_dataset('lba', 'atom3d')


In [None]:

import argparse
import datetime
import json
import os
import time
import tqdm

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.nn as nn
from atom3d.datasets import LMDBDataset
from scipy.stats import spearmanr
# Seed the global NumPy and PyTorch RNGs with a fixed value so runs are repeatable.
random_seed=3
np.random.seed(random_seed)
torch.manual_seed(random_seed)

class CNN3D_LBA(nn.Module):
    """3D convolutional network that maps a voxel grid to one scalar per sample.

    The network is assembled as a single ``nn.Sequential`` stored in
    ``self.model``: an optional input batch-norm, a stack of
    Conv3d -> ReLU [-> MaxPool3d] [-> BatchNorm3d] [-> Dropout] groups,
    a Flatten, a stack of Linear -> ReLU [-> BatchNorm1d] [-> Dropout]
    groups and a final ``Linear(in_features, 1)``.

    Args:
        in_channels: number of channels of the input grid.
        spatial_size: edge length of the (cubic) input grid; used to track the
            flattened feature count after the convolutions and poolings.
        conv_drop_rate: dropout probability after each conv group.
        fc_drop_rate: dropout probability after each fully connected group.
        conv_filters: output channel count of each conv layer.
        conv_kernel_size: kernel edge length shared by all conv layers.
        max_pool_positions: per-conv flag; truthy entries add a MaxPool3d.
        max_pool_sizes: per-conv pooling kernel size.
        max_pool_strides: per-conv pooling stride.
        fc_units: widths of the hidden fully connected layers.
        batch_norm: insert batch-norm layers when True.
        dropout: insert dropout layers when True.
    """

    def __init__(self, in_channels, spatial_size,
                 conv_drop_rate, fc_drop_rate,
                 conv_filters, conv_kernel_size,
                 max_pool_positions, max_pool_sizes, max_pool_strides,
                 fc_units,
                 batch_norm=True,
                 dropout=False):
        super(CNN3D_LBA, self).__init__()

        layers = []
        if batch_norm:
            layers.append(nn.BatchNorm3d(in_channels))

        # Convs
        for i in range(len(conv_filters)):
            layers.extend([
                nn.Conv3d(in_channels, conv_filters[i],
                          kernel_size=conv_kernel_size,
                          bias=True),
                nn.ReLU()
                ])
            # Un-padded convolution shrinks every spatial edge by (kernel - 1).
            spatial_size -= (conv_kernel_size - 1)
            if max_pool_positions[i]:
                layers.append(nn.MaxPool3d(max_pool_sizes[i], max_pool_strides[i]))
                # Same output-size formula MaxPool3d applies (no padding, dilation 1).
                spatial_size = int(np.floor((spatial_size - (max_pool_sizes[i]-1) - 1)/max_pool_strides[i] + 1))
            if batch_norm:
                layers.append(nn.BatchNorm3d(conv_filters[i]))
            if dropout:
                layers.append(nn.Dropout(conv_drop_rate))
            in_channels = conv_filters[i]

        layers.append(nn.Flatten())
        in_features = in_channels * (spatial_size**3)
        # FC layers
        for units in fc_units:
            layers.extend([
                nn.Linear(in_features, units),
                nn.ReLU()
                ])
            if batch_norm:
                # Activations are 2-D (batch, units) after Flatten/Linear, so the
                # 1-D variant is required; BatchNorm3d rejects non-5-D input.
                layers.append(nn.BatchNorm1d(units))
            if dropout:
                layers.append(nn.Dropout(fc_drop_rate))
            in_features = units

        # Final FC layer
        layers.append(nn.Linear(in_features, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the sequential stack and flatten the (batch, 1) output to (batch,)."""
        return self.model(x).view(-1)

    
from atom3d.datasets import LMDBDataset
from atom3d.util.voxelize import dotdict, get_center, gen_rot_matrix, get_grid
from torch.utils.data import DataLoader

import dotenv as de
# Load variables from a .env file; the search for it starts at the current working directory.
de.load_dotenv(de.find_dotenv(usecwd=True))


class CNN3D_TransformLBA(object):
    """Dataset transform that turns a pocket/ligand item into a voxel-grid sample.

    Keyword arguments override entries of the default grid configuration.
    """

    def __init__(self, random_seed=None, **kwargs):
        self.random_seed = random_seed
        defaults = {
            # Mapping from elements to position in channel dimension.
            'element_mapping': {
                'H': 0,
                'C': 1,
                'O': 2,
                'N': 3,
                'F': 4,
            },
            # Radius of the grids to generate, in angstroms.
            'radius': 20.0,
            # Resolution of each voxel, in angstroms.
            'resolution': 1.0,
            # Number of directions to apply for data augmentation.
            'num_directions': 20,
            # Number of rolls to apply for data augmentation.
            'num_rolls': 20,
        }
        self.grid_config = dotdict(defaults)
        # Caller-supplied settings take precedence over the defaults.
        self.grid_config.update(kwargs)

    def _voxelize(self, atoms_pocket, atoms_ligand):
        """Voxelize pocket + ligand atoms around the ligand center, channels first."""
        # The subgrid is centered on the ligand.
        center = get_center(atoms_ligand[['x', 'y', 'z']].astype(np.float32))
        # Rotation matrix generated from the configured seed.
        rotation = gen_rot_matrix(self.grid_config, random_seed=self.random_seed)
        complex_atoms = pd.concat([atoms_pocket, atoms_ligand])
        voxels = get_grid(complex_atoms,
                          center, config=self.grid_config, rot_mat=rotation)
        # The atom channel comes last; move it to the front, PyTorch style.
        return np.moveaxis(voxels, -1, 0)

    def __call__(self, item):
        """Return a dict with the voxelized 'feature', the 'label' and the item 'id'."""
        feature = self._voxelize(item['atoms_pocket'], item['atoms_ligand'])
        label = item['scores']['neglog_aff']
        return {'feature': feature, 'label': label, 'id': item['id']}

def conv_model(in_channels, spatial_size, args):
    """Build a CNN3D_LBA whose depth and regularization come from ``args``.

    Filter counts double per conv layer starting at 32, every second conv is
    followed by a 2x2x2 max-pool, and a single 512-unit hidden FC layer is used.
    """
    depth = args.num_conv
    return CNN3D_LBA(
        in_channels, spatial_size,
        conv_drop_rate=args.conv_drop_rate,
        fc_drop_rate=args.fc_drop_rate,
        conv_filters=[32 * (2**n) for n in range(depth)],
        conv_kernel_size=3,
        # Alternating off/on pattern; may be one entry longer than `depth`.
        max_pool_positions=[0, 1]*int((depth+1)/2),
        max_pool_sizes=[2]*depth,
        max_pool_strides=[2]*depth,
        fc_units=[512],
        batch_norm=args.batch_norm,
        dropout=not args.no_dropout)

def train_loop(model, loader, optimizer, device):
    """Train ``model`` for one pass over ``loader`` and return the epoch RMSE.

    The progress bar shows the square root of the running mean of per-batch
    MSE; the returned value is the RMSE over all per-sample squared errors.
    """
    model.train()

    sample_losses = []
    running_mse = 0
    bar_format = 'train loss: {:6.6f}'
    with tqdm.tqdm(total=len(loader), desc=bar_format.format(0)) as bar:
        for step, batch in enumerate(loader, start=1):
            inputs = batch['feature'].to(device).to(torch.float32)
            targets = batch['label'].to(device).to(torch.float32)
            # Clear stale gradients, then forward / backward / update.
            optimizer.zero_grad()
            squared_errors = F.mse_loss(model(inputs), targets, reduction='none')
            batch_mse = squared_errors.mean()
            batch_mse.backward()
            optimizer.step()
            # Incremental mean of the batch MSE values seen so far.
            running_mse += (batch_mse.item() - running_mse) / float(step)
            sample_losses.extend(squared_errors.tolist())
            bar.set_description(bar_format.format(np.sqrt(running_mse)))
            bar.update(1)

    return np.sqrt(np.mean(sample_losses))
import pickle

def test(model, loader, device):
    model.eval()

    losses = []

    ids = []
    
    y_true = []
    y_pred = []
    with torch.no_grad():
        for data in loader:
            feature = data['feature'].to(device).to(torch.float32)
            label = data['label'].to(device).to(torch.float32)
            output = model(feature)
            batch_losses = F.mse_loss(output, label, reduction='none')
            losses.extend(batch_losses.tolist())
            ids.extend(data['id'])
            y_true.extend(label.tolist())
            y_pred.extend(output.tolist())

        results_df = pd.DataFrame(
            np.array([ids, y_true, y_pred]).T,
            columns=['structure', 'true', 'pred'],
            )
        r_p = np.corrcoef(y_true, y_pred)[0,1]
        r_s = spearmanr(y_true, y_pred)[0]

    return np.sqrt(np.mean(losses)), r_p, r_s, results_df
def save_weights(model, weight_dir):
    """Serialize the model's ``state_dict`` to ``weight_dir`` (the target passed to ``torch.save``)."""
    state = model.state_dict()
    torch.save(state, weight_dir)
def train(args, device, test_mode=False):
    """Train a CNN3D_LBA on the LMDB splits under ``args.data_dir``.

    Writes ``config.json`` and ``best_weights.pt`` (lowest validation RMSE) to
    ``args.output_dir`` and pickles the whole model after every epoch to
    ``model_<epoch>.p`` in the current working directory. With ``test_mode``
    the best weights are reloaded and evaluated on the test split, and the
    results are written to ``test_results.pkl`` / ``test_results.txt``.

    Args:
        args: configuration object; its ``__dict__`` must be JSON-serializable.
        device: device used for the model and the batches.
        test_mode: evaluate on the test split after training when True.

    Returns:
        (best_val_loss, best_rp, best_rs): validation RMSE, Pearson R and
        Spearman R of the epoch with the lowest validation RMSE.
    """
    print("Training model with config:")
    print(str(json.dumps(args.__dict__, indent=4)) + "\n")

    # Config, weights and results are all written below this directory, so make
    # sure it exists instead of failing on the first open().
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    # Save config
    with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
        json.dump(args.__dict__, f, indent=4)

    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'),
                                transform=CNN3D_TransformLBA(random_seed=args.random_seed))
    val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'),
                              transform=CNN3D_TransformLBA(random_seed=args.random_seed))
    test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'),
                               transform=CNN3D_TransformLBA(random_seed=args.random_seed))

    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False)

    # Peek at one batch to read the channel count and grid edge length.
    for data in train_loader:
        in_channels, spatial_size = data['feature'].size()[1:3]
        print('num channels: {:}, spatial size: {:}'.format(in_channels, spatial_size))
        break

    model = conv_model(in_channels, spatial_size, args)
    print(model)
    model.to(device)

    # np.inf: the capitalized alias np.Inf was removed in NumPy 2.0.
    best_val_loss = np.inf
    best_rp = 0
    best_rs = 0

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(1, args.num_epochs+1):
        start = time.time()
        train_loss = train_loop(model, train_loader, optimizer, device)
        val_loss, r_p, r_s, val_df = test(model, val_loader, device)
        if val_loss < best_val_loss:
            print(f"\nSave model at epoch {epoch:03d}, val_loss: {val_loss:.4f}")
            save_weights(model, os.path.join(args.output_dir, f'best_weights.pt'))
            best_val_loss = val_loss
            best_rp = r_p
            best_rs = r_s
        elapsed = (time.time() - start)
        print('Epoch {:03d} finished in : {:.3f} s'.format(epoch, elapsed))
        print('\tTrain RMSE: {:.7f}, Val RMSE: {:.7f}, Pearson R: {:.7f}, Spearman R: {:.7f}'.format(
            train_loss, val_loss, r_p, r_s))
        # Per-epoch snapshot of the full model object; the context manager
        # guarantees the handle is closed even if pickling fails.
        with open("model_{}.p".format(epoch), 'wb') as file:
            pickle.dump(model, file)
    if test_mode:
        model.load_state_dict(torch.load(os.path.join(args.output_dir, f'best_weights.pt')))
        rmse, pearson, spearman, test_df = test(model, test_loader, device)
        test_df.to_pickle(os.path.join(args.output_dir, 'test_results.pkl'))
        print('Test RMSE: {:.7f}, Pearson R: {:.7f}, Spearman R: {:.7f}'.format(
            rmse, pearson, spearman))
        test_file = os.path.join(args.output_dir, f'test_results.txt')
        # Append so repeated runs (e.g. different seeds) accumulate in one file.
        with open(test_file, 'a+') as out:
            out.write('{}\t{:.7f}\t{:.7f}\t{:.7f}\n'.format(
                args.random_seed, rmse, pearson, spearman))

    return best_val_loss, best_rp, best_rs


In [None]:

class arguments():
    """Plain attribute container holding the hard-coded run configuration."""

    def __init__(self):
        # Assigned in this order so that ``__dict__`` keeps a stable key order.
        settings = (
            ('data_dir', 'atom3d/split-by-sequence-identity-30/data'),
            ('mode', 'train'),
            ('output_dir', 'savedOutputs'),
            ('unobserved', False),
            ('learning_rate', .001),
            ('conv_drop_rate', .1),
            ('fc_drop_rate', .25),
            ('num_epochs', 50),
            ('num_conv', 4),
            ('batch_norm', False),
            ('no_dropout', False),
            ('batch_size', 16),
            ('random_seed', 3),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)

# Instantiate the run configuration and use the GPU when one is available.
args=arguments()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)
# Training call is disabled; later cells load a previously pickled model instead.
#train(args, device, args.mode=='train')


In [None]:
# Load the full model object pickled by train() after epoch 50.
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
with open("model_{}.p".format(50), 'rb') as file:
    model=pickle.load(file)

In [None]:
# Show the layer list of the loaded network (the nn.Sequential built by CNN3D_LBA).
print(model.model)

In [None]:

class arguments():
    """Plain attribute container holding the hard-coded run configuration."""

    def __init__(self):
        # Assigned in this order so that ``__dict__`` keeps a stable key order.
        settings = (
            ('data_dir', 'atom3d/split-by-sequence-identity-30/data'),
            ('mode', 'train'),
            ('output_dir', 'savedOutputs'),
            ('unobserved', False),
            ('learning_rate', .001),
            ('conv_drop_rate', .1),
            ('fc_drop_rate', .25),
            ('num_epochs', 50),
            ('num_conv', 4),
            ('batch_norm', False),
            ('no_dropout', False),
            ('batch_size', 16),
            ('random_seed', 3),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)

args=arguments()
# Validation split, voxelized with the same transform and seed used in train().
val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'),
                              transform=CNN3D_TransformLBA(random_seed=args.random_seed))

# Unshuffled batches of 80 samples; the first batch is used as pruning data below.
val_loader = DataLoader(val_dataset, 80, shuffle=False)


In [None]:
# Take the first batch from the validation loader.
data=(next(iter(val_loader)))

In [None]:
# Split the batch into voxel features (x) and regression labels (y).
x=data['feature']
y=data['label']

In [None]:
# Sanity check: shape of the validation feature batch.
print(x.shape)

In [None]:
import copy
# Fraction of channels to keep per layer in the ID-pruning cell (fp = int(channels * k)).
k=.7
import scipy
def getID(k, Z, W=None, mode='kT'):
    '''
    Interpolative decomposition of Z via column-pivoted QR.

    k: number of columns to keep (must not exceed Z.shape[1])
    Z: activations after the nonlinearity, in (d, n) layout
    W: optional matrix whose columns are subsampled with the selected indices
    mode: what to return
        'kT' -> (selected column indices, T)
        'WT' -> (W restricted to the selected columns, T); requires W
    T has shape (k, n), with columns ordered like the columns of Z.
    Done in NumPy/SciPy because PyTorch has no pivoted QR.
    '''
    print(Z.shape)
    assert(k <= Z.shape[1])

    R, P = scipy.linalg.qr((Z), mode='r', pivoting=True)
    chosen = P[:k]

    if W is not None:
        Wk = W[:, chosen]

    # [I | R11^+ R12] expresses every pivoted column through the first k ones.
    leading = R[:k, :k]
    trailing = R[:k, k:]
    interp = np.hstack((np.identity(k), np.linalg.pinv(leading) @ trailing))
    # Undo the pivoting so the columns of T line up with the columns of Z.
    T = interp[:, np.argsort(P)]

    if mode == 'kT':
        return chosen, T
    if mode == 'WT' and W is not None:
        return Wk, T
    raise NotImplementedError

# Replace the unpickled model's parameters with the best-validation checkpoint written by train().
model.load_state_dict(torch.load(os.path.join(args.output_dir, f'best_weights.pt')))

import matplotlib.pyplot as plt
%matplotlib inline
import numpy.linalg as ln 



In [None]:
# Interpolative-decomposition (ID) pruning of the conv layers at indices 0, 3
# and 7 of the sequential model, one stage per layer. Each stage:
#   1. runs the validation batch `x` up to the layer's activation,
#   2. flattens the activations so that every column is one output channel,
#   3. picks fp = int(channels * k) channels and an interpolation matrix T
#      with getID, and plots the singular values of the activation matrix,
#   4. folds T into the input channels of the next conv layer and keeps only
#      the selected filters/biases of the current one.
# All forward passes run under no_grad: only activations are needed here.
pruned=copy.deepcopy(model)

# ---- Stage 1: prune model[0], consumer is model[3] ----
with torch.no_grad():
    Z=pruned.model[0](x.cuda())
    Z=pruned.model[1](Z)
    holder=copy.deepcopy(Z)
    fi = Z.shape[1]
    # Channels last, then collapse all other dims: (samples * voxels, channels).
    Zr = Z.permute(0,2,3,4,1).reshape(-1,fi)
    fp=int(Zr.shape[1]*k)
    (k_idx,T) = getID(fp, Zr.cpu())
    sv=ln.svd(Zr.cpu(), full_matrices=False, compute_uv=False)
plt.figure()
plt.semilogy(sv)
T = torch.Tensor(T)
with torch.no_grad():
    Wnext=pruned.model[3].weight.clone()
    saved=copy.deepcopy(Wnext)
    # Move the input-channel axis last so the matmul with T.T contracts over it.
    Wnext = Wnext.permute(0,2,3,4,1)
    Wnext = torch.matmul(Wnext.cpu(), T.T)
    Wnext = Wnext.permute(0,4,1,2,3)
pruned.model[3].weight=nn.Parameter(Wnext, requires_grad=True)
pruned.model[0].weight=nn.Parameter(pruned.model[0].weight[k_idx,:].clone(), requires_grad=True)
pruned.model[0].bias=nn.Parameter(pruned.model[0].bias[k_idx].clone(), requires_grad=True)
pruned.model[0].out_channels=fp
pruned.model[3].in_channels=fp

# ---- Stage 2: prune model[3], consumer is model[7] ----
# The new weights were built on the CPU; move the whole model back to the GPU.
pruned.cuda()
with torch.no_grad():
    Z=pruned.model[0](x.cuda())
    Z=pruned.model[1](Z)
    Z=pruned.model[3](Z)
    Z=pruned.model[4](Z)
    holder=copy.deepcopy(Z)
    fi = Z.shape[1]
    Zr = Z.permute(0,2,3,4,1).reshape(-1,fi)
    fp=int(Zr.shape[1]*k)
    (k_idx,T) = getID(fp, Zr.cpu())
    sv=ln.svd(Zr.cpu(), full_matrices=False, compute_uv=False)
plt.figure()
# Dump the singular values, headed by the actual size of the activation matrix.
with open("secondLayerSingularValues.txt", 'w') as file:
    string='matrix size: [{}, {}]\n'.format(Zr.shape[0], Zr.shape[1])
    for value in sv:
        string+=str(value)+'\n'
    file.write(string)
#np.savetxt("secondLayerMat.csv", Zr.cpu().numpy(), delimiter=",")
plt.semilogy(sv)
T = torch.Tensor(T)
with torch.no_grad():
    Wnext=pruned.model[7].weight.clone()
    saved=copy.deepcopy(Wnext)
    Wnext = Wnext.permute(0,2,3,4,1)
    Wnext = torch.matmul(Wnext.cpu(), T.T)
    Wnext = Wnext.permute(0,4,1,2,3)
pruned.model[7].weight=nn.Parameter(Wnext, requires_grad=True)
pruned.model[3].weight=nn.Parameter(pruned.model[3].weight[k_idx,:].clone(), requires_grad=True)
pruned.model[3].bias=nn.Parameter(pruned.model[3].bias[k_idx].clone(), requires_grad=True)
pruned.model[3].out_channels=fp
pruned.model[7].in_channels=fp


# ---- Stage 3: prune model[7], consumer is model[10] ----
pruned.cuda()
with torch.no_grad():
    Z=pruned.model[0](x.cuda())
    Z=pruned.model[1](Z)
    Z=pruned.model[3](Z)
    Z=pruned.model[4](Z)
    Z=pruned.model[5](Z)
    Z=pruned.model[7](Z)
    Z=pruned.model[8](Z)
    holder=copy.deepcopy(Z)
    fi = Z.shape[1]
    Zr = Z.permute(0,2,3,4,1).reshape(-1,fi)
    fp=int(Zr.shape[1]*k)
    (k_idx,T) = getID(fp, Zr.cpu())
    sv=ln.svd(Zr.cpu(), full_matrices=False, compute_uv=False)
plt.figure()
plt.semilogy(sv)
T = torch.Tensor(T)
with torch.no_grad():
    Wnext=pruned.model[10].weight.clone()
    saved=copy.deepcopy(Wnext)
    Wnext = Wnext.permute(0,2,3,4,1)
    Wnext = torch.matmul(Wnext.cpu(), T.T)
    Wnext = Wnext.permute(0,4,1,2,3)
pruned.model[10].weight=nn.Parameter(Wnext, requires_grad=True)
pruned.model[7].weight=nn.Parameter(pruned.model[7].weight[k_idx,:].clone(), requires_grad=True)
pruned.model[7].bias=nn.Parameter(pruned.model[7].bias[k_idx].clone(), requires_grad=True)
pruned.model[7].out_channels=fp
pruned.model[10].in_channels=fp




In [None]:
# Per-sample input shape handed to model_summary in a later cell
# (presumably 5 element channels on a 41^3 grid; confirm against x.shape).
summary_input = (5,41,41, 41)


def prune(model, x, k=.9):
    """Prune one conv layer of a copy of ``model`` by interpolative decomposition.

    For the conv layers at indices 0, 3, 7 and 10 a score is computed from the
    activations on ``x``: singular value ``sv[fp]`` relative to ``sv[0]``,
    divided by the conv FLOPs associated with that layer. The layer with the
    lowest score is reduced to ``fp = int(channels * k)`` output channels and
    the interpolation matrix is folded into the layer that consumes it (the
    next conv, or the linear layer at index 15 for the last conv).

    Relies on the globals ``print_model_param_flops``, ``getID``, ``test``,
    ``test_loader`` and ``device``.

    Returns:
        (pruned_model, [rmse, pearson, spearman] on the test set, total FLOPs).
    """
    pruned = copy.deepcopy(model)
    scores = []

    flops = print_model_param_flops(pruned, input_res=41)[1]
    print(flops)
    best_score = 1

    # Each stage: (layers to run next, conv-FLOP entries forming the cost,
    #              index of the conv to prune, index of the layer consuming it).
    stages = (
        ((0, 1), (0, 1), 0, 3),
        ((3, 4), (1, 2), 3, 7),
        ((5, 7, 8), (2, 3), 7, 10),
        ((10, 11), (3,), 10, 15),
    )

    # Calculate scores, remembering the decomposition of the best layer so far.
    with torch.no_grad():
        Z = x.cuda()
        for layer_ids, flop_ids, conv_idx, consumer_idx in stages:
            for layer_id in layer_ids:
                Z = pruned.model[layer_id](Z)
            n_channels = Z.shape[1]
            # Channels last, everything else collapsed into rows.
            Zr = Z.permute(0, 2, 3, 4, 1).reshape(-1, n_channels)
            n_keep = int(Zr.shape[1]*k)
            sv = ln.svd(Zr.cpu(), full_matrices=False, compute_uv=False)
            cost = sum(flops[j] for j in flop_ids)
            scores.append(sv[n_keep]/sv[0]/cost)
            if scores[-1] < best_score:
                best_score = scores[-1]
                (k_idx, T) = getID(n_keep, Zr.cpu())
                T = torch.Tensor(T)
                currentLayer = conv_idx
                nextLayer = consumer_idx
                f = n_keep
    print(T.shape)

    # Prune the selected layer.
    target = pruned.model[currentLayer]
    consumer = pruned.model[nextLayer]
    if currentLayer != 10:
        # Consumer is a conv: contract T over its input-channel axis.
        with torch.no_grad():
            Wnext = consumer.weight.clone().permute(0, 2, 3, 4, 1)
            Wnext = torch.matmul(Wnext.cpu(), T.T).permute(0, 4, 1, 2, 3)
        consumer.weight = nn.Parameter(Wnext, requires_grad=True)
        consumer.in_channels = f
    else:
        # Consumer is the linear layer after Flatten: every channel spans n
        # consecutive input features, so expand T with a Kronecker product.
        n = int(consumer.in_features / target.out_channels)
        T = torch.kron(T.contiguous(), torch.eye(n))
        print(T.shape)
        consumer.weight = nn.Parameter(consumer.weight.cpu() @ T.T,
                    requires_grad=True)
        consumer.in_features = f * n
    # Keep only the selected filters and biases of the pruned conv layer.
    target.weight = nn.Parameter(target.weight[k_idx,:].clone(), requires_grad=True)
    target.bias = nn.Parameter(target.bias[k_idx].clone(), requires_grad=True)
    target.out_channels = f

    # Evaluate the pruned model on the test set.
    rmse, pearson, spearman, test_df = test(pruned.cuda(), test_loader, device)
    return pruned, [rmse, pearson, spearman], print_model_param_flops(pruned, input_res=41)[0]
# Prune iteratively for 100 rounds, recording the test metrics and FLOPs after
# each round. prune() deep-copies its input, so `model` itself is not modified.
# NOTE(review): prune() uses `test_loader` and `print_model_param_flops`, which
# are only defined in later cells of this notebook; run those cells first.
losses=[]
flops=[]
pruned=model
for i in range(0, 100):
    pruned, t, fs=prune(pruned, x)
    losses.append(t)
    flops.append(fs)
    print(t)
    print(fs)

In [None]:
import pickle
# Persist the pruning curve (metrics and FLOPs per round). The previous version
# had the open/dump lines commented out but still called file.close() on a
# stale handle from an earlier cell; the context manager keeps open and close
# together. NOTE: this overwrites an existing atom3dDump.p.
with open("atom3dDump.p", 'wb') as file:
    pickle.dump([losses, flops], file)

In [None]:
import matplotlib.pyplot as plt
# Test-set metrics of the unpruned model; only the RMSE is printed.
rmse, pearson, spearman, test_df=test(model.cuda(), test_loader, device)
print(rmse)

In [None]:
# Rows of `ls` are [rmse, pearson, spearman] per pruning round; `fs` is the
# FLOP count of each pruned model relative to the unpruned model.
ls=np.array(losses)
fs=np.array(flops)/print_model_param_flops(model,input_res=41 )[0]

In [None]:
# Test RMSE (column 0 of `ls`) against the relative FLOP count.
plt.plot(fs, ls[:, 0])
plt.xlabel("flops")
plt.ylabel("loss")
# Hard-coded reference line at 1.43 (presumably the unpruned test RMSE; confirm).
plt.axhline(y=1.43, color='r', linestyle='-')

In [None]:
# Test split with the same transform and seed as the validation data; this
# global `test_loader` is the one prune() and the evaluation cells rely on.
test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'),
                              transform=CNN3D_TransformLBA(random_seed=args.random_seed))

test_loader = DataLoader(test_dataset, 16, shuffle=False)

# test() switches the model it receives to eval mode itself.
model.eval()
test(pruned.cuda(), test_loader, device)

In [None]:
# Evaluate the pruned model on the test set and report RMSE and both correlations.
rmse, pearson, spearman, test_df = test(pruned.cuda(), test_loader, device)
print('Test RMSE: {:.7f}, Pearson R: {:.7f}, Spearman R: {:.7f}'.format(
            rmse, pearson, spearman
))

In [None]:
# Shape of the validation feature batch used as pruning data.
print(x.shape)

In [None]:
# Per-sample input shape passed to model_summary below (same value as set earlier in the notebook).
summary_input = (5,41,41, 41)

In [None]:
from train import model_summary

In [None]:
# Summary of the unpruned model.
print(model_summary(model,summary_input= summary_input, input_res=41))

In [None]:
# Summary of the pruned model, for comparison with the unpruned one.
print(model_summary(pruned,summary_input= summary_input, input_res=41))

In [None]:
from torch.autograd import Variable
def print_model_param_flops(model=None, input_res=224, multiply_adds=True):
    """Estimate the per-sample FLOPs of a 3D CNN with forward hooks.

    A deep copy of ``model`` is run once on a random probe batch of shape
    (3, 5, input_res, input_res, input_res); hooks on the leaf modules record
    the operation counts. The probe is placed on the device of the model's
    parameters, so the model may live on CPU or GPU.

    Args:
        model: network to analyse (not modified; hooks go on a copy).
        input_res: edge length of the cubic probe input.
        multiply_adds: count a multiply-accumulate as 2 operations when True.

    Returns:
        (total_flops_per_sample, list_conv, list_linear). The two lists hold
        the per-layer conv / linear FLOPs for the whole probe batch of 3.
    """
    probe_batch = 3
    ops_per_mac = 2 if multiply_adds else 1

    list_conv = []
    def conv_hook(self, input, output):
        kernel_volume = 1
        for size in self.kernel_size:
            kernel_volume *= size
        kernel_ops = kernel_volume * (self.in_channels / self.groups)
        bias_ops = 1 if self.bias is not None else 0
        # output.nelement() == batch * out_channels * product of spatial dims.
        list_conv.append((kernel_ops * ops_per_mac + bias_ops) * output.nelement())

    list_linear = []
    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1
        weight_ops = self.weight.nelement() * ops_per_mac
        bias_ops = self.bias.nelement() if self.bias is not None else 0
        list_linear.append(batch_size * (weight_ops + bias_ops))

    list_bn = []
    def bn_hook(self, input, output):
        list_bn.append(input[0].nelement() * 2)

    list_relu = []
    def relu_hook(self, input, output):
        list_relu.append(input[0].nelement())

    list_pooling = []
    def pooling_hook(self, input, output):
        # One op per kernel element per output value; kernel_size may be an
        # int (same in every spatial dim) or a tuple.
        spatial_dims = output.dim() - 2
        if isinstance(self.kernel_size, int):
            kernel_ops = self.kernel_size ** spatial_dims
        else:
            kernel_ops = 1
            for size in self.kernel_size:
                kernel_ops *= size
        list_pooling.append(kernel_ops * output.nelement())

    list_upsample = []
    # For bilinear upsample
    def upsample_hook(self, input, output):
        list_upsample.append(output.nelement() * 12)

    # The 3-D variants are what CNN3D_LBA uses; the 2-D ones are kept so models
    # containing them are still counted.
    bn_types = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)
    pool_types = (torch.nn.MaxPool2d, torch.nn.AvgPool2d,
                  torch.nn.MaxPool3d, torch.nn.AvgPool3d)

    def register_hooks(net):
        children = list(net.children())
        if children:
            for child in children:
                register_hooks(child)
            return
        if isinstance(net, torch.nn.Conv3d):
            net.register_forward_hook(conv_hook)
        if isinstance(net, torch.nn.Linear):
            net.register_forward_hook(linear_hook)
        if isinstance(net, bn_types):
            net.register_forward_hook(bn_hook)
        if isinstance(net, torch.nn.ReLU):
            net.register_forward_hook(relu_hook)
        if isinstance(net, pool_types):
            net.register_forward_hook(pooling_hook)
        if isinstance(net, torch.nn.Upsample):
            net.register_forward_hook(upsample_hook)

    m = copy.deepcopy(model)
    register_hooks(m)
    # Run the probe where the model lives rather than assuming it is on CUDA.
    first_param = next(m.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    probe = torch.rand(probe_batch, 5, input_res, input_res, input_res).to(device)
    with torch.no_grad():
        m(probe)

    print(list_conv, list_linear, list_bn, list_pooling, list_relu)
    total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample))

    print('  + Number of FLOPs: %.5fG' % (total_flops / probe_batch / 1e9))

    return total_flops / probe_batch, list_conv, list_linear

# FLOP estimate for the unpruned model on a 41^3 input.
print_model_param_flops(model,input_res=41 )

In [None]:
# FLOP estimate for the pruned model on a 41^3 input.
print_model_param_flops(pruned,input_res=41 )

In [None]:
import torch.nn.utils.prune as prune
def conv_model(in_channels, spatial_size, args, k=1.0):
    """Build a CNN3D_LBA whose filter counts are scaled by ``k``.

    Same architecture recipe as the unscaled builder (filters double per conv
    layer, a 2x2x2 max-pool after every second conv, one 512-unit FC layer),
    but the base filter count is ``int(32*k)`` instead of 32.
    """
    depth = args.num_conv
    base_filters = int(32*k)
    return CNN3D_LBA(
        in_channels, spatial_size,
        conv_drop_rate=args.conv_drop_rate,
        fc_drop_rate=args.fc_drop_rate,
        conv_filters=[base_filters * (2**n) for n in range(depth)],
        conv_kernel_size=3,
        # Alternating off/on pattern; may be one entry longer than `depth`.
        max_pool_positions=[0, 1]*int((depth+1)/2),
        max_pool_sizes=[2]*depth,
        max_pool_strides=[2]*depth,
        fc_units=[512],
        batch_norm=args.batch_norm,
        dropout=not args.no_dropout)
# Magnitude-pruning baseline: for each pruning amount, record the FLOPs of an
# equivalently narrowed network and the test RMSE of the trained model after
# structured L2 pruning of its conv layers.
# NOTE: in this cell `prune` is the torch.nn.utils.prune module imported above,
# which shadows the notebook's own prune() function from here on.
amounts=[.05,.075,.1,.13, .15,.175, .2,.25, .3, .4,.5]
# Both lists are pickled together in the next cell, so both must be filled here.
rmses=[]
flops=[]

# Read channel count and grid size from the first test batch.
for data in test_loader:
    in_channels, spatial_size = data['feature'].size()[1:3]
    print('num channels: {:}, spatial size: {:}'.format(in_channels, spatial_size))
    break


for amount in amounts:
    # An untrained model with (1 - amount) of the filters stands in for the FLOP count.
    dummy=conv_model(in_channels, spatial_size, args, k=1-amount)
    flops.append(print_model_param_flops(dummy.cuda(),input_res=41 )[0])
    new=copy.deepcopy(model)
    for name, module in new.named_modules():
        # Zero out the fraction `amount` of slices along dim 1 with the
        # smallest L2 norm in every Conv3d layer.
        if isinstance(module, torch.nn.Conv3d):
            prune.ln_structured(module, name='weight', amount=amount, dim=1, n=2)
    rmses.append(test(new.cuda(), test_loader, device)[0])


In [None]:
# Persist the magnitude-pruning curve; the context manager closes the file even on error.
with open("atom3dmag.p", 'wb') as file:
    pickle.dump([flops, rmses],file)

In [None]:
# FLOP estimate for the last magnitude-pruned copy of the model.
print_model_param_flops(new,input_res=41 )

In [None]:
# FLOPs recorded for each pruning amount in the sweep above.
print(flops)