In [None]:
from pointnet.pn_dataset import PointNetDataset
from pointnet.model import PointNetLayer, PointNet

In [None]:
from torch_geometric.loader import DataLoader
import numpy as np
import glob
import torch
import pandas as pd
from torch_cluster import knn_graph
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
import torch

In [None]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Compare on `.type`: a `torch.device` object is not equal to the plain string
# 'cpu', so a direct `device == 'cpu'` check would never print the warning.
if device.type == 'cpu':
    print('Warning, training on cpu')

### Load data 

In [None]:
# Root folder of the symmetry data set; the class name is the directory level
# directly below it (matched by the first `*` in the patterns).
root_dir = '/home/dim26fa/data/imod_models/symmetry/'
# glob returns matches in arbitrary, filesystem-dependent order. Sorting makes
# the sample order (and any index derived from it) reproducible across runs.
train_pc = sorted(glob.glob(root_dir + '/*/preprocessed/train/*/*Localizations*.csv'))
test_pc = sorted(glob.glob(root_dir + '/*/preprocessed/test/*/*Localizations*.csv'))

In [None]:
import random

from torch_geometric.transforms import Compose, RandomRotate

# Seed both generators: torch's own RNG, and Python's `random` module, which
# torch_geometric's RandomRotate uses to draw its rotation angle. Seeding only
# torch leaves the augmentation non-reproducible.
torch.manual_seed(123)
random.seed(123)

# One random rotation per coordinate axis; a scalar `degrees=180` means the
# angle is sampled from the interval [-180, 180].
random_rotate = Compose([
    RandomRotate(degrees=180, axis=axis) for axis in (0, 1, 2)
])


In [None]:
# Wrap the globbed CSV paths in the project's dataset class.
# NOTE(review): the random rotation is passed to the test dataset as well, so
# evaluation samples may differ between passes — confirm this is intended.
train_dataset = PointNetDataset(train_pc, transform=random_rotate)
test_dataset = PointNetDataset(test_pc, transform=random_rotate)

In [None]:
# Mini-batch loaders: 20 samples per training batch (reshuffled by the loader),
# 10 samples per test batch with the loader's default ordering.
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10)

In [None]:
# Sanity check: print the label tensor `y` of every training batch.
for data in train_loader:
    print(data.y)

### Try another data class

In [None]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

train_list = []
for idx, path in enumerate(train_pc):
    processed_dir = os.path.dirname(path)
    df = pd.read_csv(path)
    arr = np.array(df)
    tens = torch.tensor(arr, dtype=torch.float)
        #edge_index = knn_graph(tens, k=6)
    class_name = path.split('/')[-5]
    if class_name == 'npc_6fold':
        label = 0
    else:
        label = 1
    data = Data(pos=tens,
                #edge_index=edge_index,
                y=torch.tensor(label, dtype=torch.int64)
                )
    torch.save(data, osp.join(processed_dir, f'data_{idx}.pt'))
    train_list.append(data)

train_loader = DataLoader(train_list, batch_size=20, shuffle=True)

In [None]:
test_list = []
for idx, path in enumerate(test_pc):
    processed_dir = os.path.dirname(path)
    df = pd.read_csv(path)
    arr = np.array(df)
    tens = torch.tensor(arr, dtype=torch.float)
        #edge_index = knn_graph(tens, k=6)
    class_name = path.split('/')[-5]
    if class_name == 'npc_6fold':
        label = 0
    else:
        label = 1
    data = Data(pos=tens,
                #edge_index=edge_index,
                y=torch.tensor(label, dtype=torch.int64)
                )
    torch.save(data, osp.join(processed_dir, f'data_{idx}.pt'))
    test_list.append(data)

test_loader = DataLoader(test_list, batch_size=10)

In [None]:
# Sanity check on the list-backed loader: print each batch's label tensor.
for i in train_loader:
    print(i.y)

In [None]:
# Peek at the `batch` attribute of the first un-collated sample.
# NOTE(review): presumably None here, since batch vectors are only attached
# when a loader collates samples — confirm.
train_loader.dataset[0].batch

In [None]:
# Parent directory of every test CSV, in the same order as `test_pc`
# (duplicates are kept when several files share a folder).
dirnames = [os.path.dirname(csv_path) for csv_path in test_pc]

In [None]:
# Delete every `.pt` file that sits directly inside one of the test sample
# directories collected in `dirnames`.
for sample_dir in dirnames:
    for stale_pt in glob.glob(sample_dir + '/*.pt'):
        os.remove(stale_pt)

### Load model

In [None]:
# TensorBoard event writer; with no arguments it logs to SummaryWriter's
# default log directory.
writer = SummaryWriter()

In [None]:
# Instantiate the classifier with its default constructor arguments.
# NOTE(review): the model is never moved to `device`, so it stays wherever
# PointNet() creates its parameters (presumably the CPU) — confirm.
model = PointNet()

In [None]:
# Adam over all model parameters with a learning rate of 3e-5.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.

### Training and testing functions 

In [None]:
def train(model, optimizer, loader, epoch=None):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Relies on the module-level ``criterion`` (loss function) and ``writer``
    (TensorBoard ``SummaryWriter``).

    Parameters
    ----------
    model : torch.nn.Module
        Network called as ``model(pos, batch)``.
    optimizer : torch.optim.Optimizer
        Optimizer that updates ``model``'s parameters.
    loader : iterable
        Yields batches exposing ``pos``, ``batch``, ``y``, ``num_graphs`` and
        ``to(device)``; must also expose a sized ``dataset`` attribute.
    epoch : int, optional
        TensorBoard step for the logged loss. When omitted, the number of
        previous ``train`` calls is used so successive epochs get successive
        steps instead of all landing on the same one.

    Returns
    -------
    float
        Sum of per-batch losses weighted by batch size, divided by
        ``len(loader.dataset)``.
    """
    model.train()
    # Feed batches on whatever device the model's parameters live on.
    device = next(model.parameters()).device
    # Normalise by the loader that was passed in, not by a global loader.
    n_samples = len(loader.dataset)

    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()  # Clear gradients.
        logits = model(data.pos, data.batch)  # Forward pass.
        loss = criterion(logits, data.y)  # Loss computation.
        loss.backward()  # Backward pass.
        optimizer.step()  # Update model parameters.
        # Weight the batch loss by the batch size so the final division gives
        # a per-sample mean even when the last batch is smaller.
        total_loss += loss.item() * data.num_graphs

    mean_loss = total_loss / n_samples

    # Log once per pass, with an explicit step, rather than once per batch
    # with a partial running sum and no step.
    calls_so_far = getattr(train, '_calls', 0)
    train._calls = calls_so_far + 1
    step = calls_so_far if epoch is None else epoch
    writer.add_scalar('training loss', mean_loss, global_step=step)
    return mean_loss


@torch.no_grad()
def test(model, loader):
    model.eval()

    total_correct = 0
    mis=[]
    for data in loader:
        logits = model(data.pos, data.batch)
        pred = logits.argmax(dim=-1)
        total_correct += int((pred == data.y).sum())
        for i, label in enumerate(data.y):
            if pred[i]!=label:
                mis.append(data[i].pos)
    return total_correct / len(loader.dataset), mis

### Training and testing procedure

In [None]:
# Epochs are numbered 1..LAST_EPOCH; misclassified test clouds are only kept
# for epochs after COLLECT_MISSES_AFTER, so entry i of `misclassified` belongs
# to epoch COLLECT_MISSES_AFTER + 1 + i.
LAST_EPOCH = 799
COLLECT_MISSES_AFTER = 600

losses = []
accuracies = []
misclassified = []
for epoch in range(1, LAST_EPOCH + 1):
    loss = train(model, optimizer, train_loader)
    test_acc, miss = test(model, test_loader)
    losses.append(loss)
    accuracies.append(test_acc)
    if epoch > COLLECT_MISSES_AFTER:
        misclassified.append(miss)
    # Record the evaluation metric with an explicit step so TensorBoard can
    # draw it as a curve over epochs.
    writer.add_scalar('test accuracy', test_acc, global_step=epoch)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')
# Push any buffered TensorBoard events to disk.
writer.flush()

In [None]:
# Release the TensorBoard writer; no further scalars can be logged through it
# after this cell.
writer.close()

In [None]:
# Persist only the learned weights (the state_dict), not the whole module.
# The previous target was built as `<dir>/` + '.pth', i.e. a hidden file whose
# entire name is `.pth`; give the checkpoint a real file name and make sure the
# directory exists before writing.
MODEL_DIR = '/home/dim26fa/coding/smlm_v2/models/11082022_pointnet_symmetry_6fold8fold/'
os.makedirs(MODEL_DIR, exist_ok=True)
torch.save(model.state_dict(), osp.join(MODEL_DIR, 'pointnet.pth'))

In [None]:
# Show the second misclassified point cloud from the first collected epoch.
misclassified[0][1]

In [None]:
# Write every collected misclassified point cloud to its own .npy file,
# numbered consecutively across all collected epochs.
# The outer loop variable must not be called `list`: that would rebind the
# builtin for the rest of the kernel session.
k = 0
for epoch_misses in misclassified:
    for pc in epoch_misses:
        np.save('/home/dim26fa/coding/smlm_v2/models/11082022_pointnet_symmetry_6fold8fold/misclassified' + str(k) + '.npy', pc)
        k += 1

In [None]:
# Number of misclassified test samples stored at list index 180.
len(misclassified[180])

In [None]:
# Number of misclassified test samples stored at list index 77.
len(misclassified[77])

In [None]:
# Dump only the point clouds stored at list index 77, one .npy file each.
# NOTE(review): these file names are the same `misclassified<k>.npy` names used
# by the cell that dumps every epoch, so running both overwrites files —
# confirm that is intended.
selected_misses = misclassified[77]
for k in range(len(selected_misses)):
    pc = selected_misses[k]
    np.save('/home/dim26fa/coding/smlm_v2/models/11082022_pointnet_symmetry_6fold8fold/misclassified' + str(k) + '.npy', pc)
# Leave the counter at the number of files written, as a manual counter would.
k = len(selected_misses)

In [None]:
# `misclassified` is a nested list (one list of point-cloud tensors per
# collected epoch) whose inner lengths and cloud sizes may differ, so it cannot
# be converted into a regular array. Store it as an explicit 1-D object array,
# one list of NumPy arrays per epoch.
# Read it back with np.load(..., allow_pickle=True).
misclassified_arr = np.empty(len(misclassified), dtype=object)
for epoch_idx, epoch_misses in enumerate(misclassified):
    misclassified_arr[epoch_idx] = [np.asarray(pc) for pc in epoch_misses]
np.save('/home/dim26fa/coding/smlm_v2/models/11082022_pointnet_symmetry_6fold8fold/misclassified.npy', misclassified_arr, allow_pickle=True)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Test accuracy per epoch, labelled so the figure can be read on its own.
# The x axis starts at 1 because the training loop numbers epochs from 1.
fig, ax = plt.subplots()
ax.plot(range(1, len(accuracies) + 1), accuracies, color='#00B6CA')
ax.set(xlabel='Epoch', ylabel='Test accuracy', title='Test accuracy per epoch');

In [None]:
# Mean training loss per epoch, labelled so the figure can be read on its own.
# The x axis starts at 1 because the training loop numbers epochs from 1.
fig, ax = plt.subplots()
ax.plot(range(1, len(losses) + 1), losses, color='#00B6CA')
ax.set(xlabel='Epoch', ylabel='Training loss', title='Training loss per epoch');

### NPC data from Heydarian et al. 

In [None]:
import scipy.io
import torch

In [None]:
from helpers.visualization import Visualization3D
from helpers.readers import Reader

In [None]:
# Point the project's Reader helper at the 6-fold training directory.
reader = Reader('/home/dim26fa/data/imod_models/symmetry/npc_6fold/preprocessed/train/')

In [None]:
# Discover the sample folders, select folder number 145 and filter on '.csv'.
# NOTE(review): the exact semantics of these Reader helpers (indexing, whether
# they mutate the reader) are not visible in this notebook — confirm.
reader.get_folders()
reader.get_files_from_folder(145)
reader.filter('.csv')

In [None]:
# Select file number 0 of the filtered list and load it.
# NOTE(review): presumably read_csv() also stores the frame on the reader —
# confirm against helpers.readers.
reader.set_file(0)
reader.read_csv()

In [None]:
# Extract the 'x', 'y', 'z' columns; later cells read them as `reader.df_xyz`.
reader.extract_xyz(column_names=['x','y','z'])

In [None]:
# Drill into the nested structure of the loaded MATLAB file.
# NOTE(review): `matfile` is never assigned in the visible notebook. It was
# presumably created with scipy.io.loadmat in a cell that no longer exists;
# this cell fails on a fresh kernel until that load is restored.
matfile['particles'][0][0][0][0][0][0]

In [None]:
# First row of the 'particles' entry; later cells index it per particle.
# NOTE(review): depends on `matfile`, which is not defined in the visible
# notebook.
particles = matfile['particles'][0]

In [None]:
# Load particle number 200 into a frame (xyz plus seven further columns),
# keep only the coordinates and display them as a NumPy array.
coord_columns = ['x','y','z','p1','p2','p3','p4','p5','p6', 'p7']
particle1 = particles[200]
coords_200 = particle1[0]['coords'][0]
df = pd.DataFrame(coords_200, columns=coord_columns)
df_xyz = df[['x','y','z']]
df_xyz.to_numpy()

In [None]:
# Number of rows in the xyz frame of the selected particle.
len(df_xyz)

In [None]:
# 3D scatter of the reader's xyz frame with marker size 0.5.
# NOTE(review): `viz` holds whatever `.show()` returns, not the figure —
# confirm against helpers.visualization.
viz = Visualization3D(reader.df_xyz).get_3d_scatter(size=0.5).show()

In [None]:
# Convert the reader's xyz frame to a tensor; the dtype is inherited from the
# NumPy array (no explicit float32 cast here).
tens = torch.tensor(np.array(reader.df_xyz))

In [None]:
# k-nearest-neighbour graph over the points in `tens`, with k=25.
knngr = knn_graph(x=tens, k=25)

In [None]:
# Display the result of knn_graph.
knngr

In [None]:
from sklearn.neighbors import KNeighborsTransformer
import networkx as nx
from pyvis.network import Network

In [None]:
# scikit-learn neighbour transformer (25 neighbours, ball-tree search) fitted
# on the reader's xyz frame; the transformed output is only displayed.
transformer = KNeighborsTransformer(n_neighbors = 25, algorithm='ball_tree')
transformer.fit_transform(reader.df_xyz)

In [None]:
# Neighbour graph of the fitted points, densified with toarray() so networkx
# can build a graph from it.
graph = transformer.kneighbors_graph()
nx_graph = nx.from_numpy_array(graph.toarray())

In [None]:
# Render the networkx graph with pyvis in notebook mode, exposing only the
# 'physics' controls, and write the visualisation to graph.html.
net = Network(notebook = True)
net.show_buttons(filter_=['physics'])
net.from_nx(nx_graph)
net.show('graph.html')

In [None]:
# Export the xyz coordinates of every particle to its own CSV file,
# numbered by the particle's position in `particles`.
for i, particle in enumerate(particles):
    coords = particle[0]['coords'][0]
    df = pd.DataFrame(coords, columns = ['x','y','z','p1','p2','p3','p4','p5','p6', 'p7'])
    df_xyz = df[['x','y','z']]
    target_csv = '/home/dim26fa/data/NPC_Heydarian/npc_' + str(i) + '.csv'
    df_xyz.to_csv(target_csv, index=False)

In [None]:
import os
import shutil

In [None]:
# Move each flat `npc_Localizations_<i>.csv` (i = 51..255) into its own
# `sample_<i>` folder. The cell is safe to re-run: files that are already gone
# are skipped, and existing target folders are reused instead of raising.
train_root = '/home/dim26fa/data/imod_models/binary/npc/train/'
for i in range(51,256):
    src = train_root + 'npc_Localizations_' + str(i) + '.csv'
    dst_dir = train_root + 'sample_' + str(i)
    if not osp.exists(src):
        # Already moved by an earlier run, or never exported: nothing to do.
        continue
    os.makedirs(dst_dir, exist_ok=True)
    shutil.move(src, dst_dir)

In [None]:
from helpers.visualization import Visualization3D

In [None]:
# Second misclassified cloud at list index 77 as an x/y/z frame for plotting.
df = pd.DataFrame(misclassified[77][1].numpy(), columns = ['x','y','z'])

In [None]:
# 3D scatter of the frame in `df` with marker size 1.
viz = Visualization3D(df).get_3d_scatter(size=1).show()

In [None]:
from helpers.readers import Reader

In [None]:
# Point the project's Reader helper at the 8-fold test directory.
reader = Reader('/home/dim26fa/data/imod_models/symmetry/npc_8fold/preprocessed/test/')

In [None]:
# List the sample folders below the reader's root.
# NOTE(review): presumably this also caches them on the reader for
# get_files_from_folder — confirm.
reader.get_folders()

In [None]:
# Select folder number 11, filter on '.csv', load file number 0 and show it as
# a 3D scatter with marker size 0.5.
# NOTE(review): the whole frame returned by read_csv() is plotted, so this
# assumes Visualization3D picks out the coordinate columns itself — confirm.
reader.get_files_from_folder(11)
reader.filter('.csv')
reader.set_file(0)
df = reader.read_csv()
viz = Visualization3D(df).get_3d_scatter(size=0.5).show()