In [None]:
import glob
import os
import os.path as osp
from typing import Callable, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data, InMemoryDataset

In [None]:
def visualize_points(pos, edge_index=None, index=None):
    """Draw a 2D scatter plot of a point set, optionally with edges.

    Only the first two coordinates of every point are plotted.

    Args:
        pos: Tensor of point coordinates, indexed as ``pos[node, coord]``.
        edge_index: Optional tensor whose transpose yields one
            ``(source, target)`` node-index pair per row; each pair is drawn
            as a thin black line segment.
        index: Optional node indices to highlight. When given, all other
            points are drawn in light gray and the selected ones in the
            default colour.
    """
    plt.figure(figsize=(4, 4))

    if edge_index is not None:
        for start, end in edge_index.t().tolist():
            start_xy = pos[start].tolist()
            end_xy = pos[end].tolist()
            plt.plot([start_xy[0], end_xy[0]], [start_xy[1], end_xy[1]],
                     linewidth=1, color='black')

    if index is not None:
        # Boolean mask that is True exactly at the highlighted nodes.
        selected = torch.zeros(pos.size(0), dtype=torch.bool)
        selected[index] = True
        plt.scatter(pos[~selected, 0], pos[~selected, 1], s=50,
                    color='lightgray', zorder=1000)
        plt.scatter(pos[selected, 0], pos[selected, 1], s=50, zorder=1000)
    else:
        plt.scatter(pos[:, 0], pos[:, 1], s=50, zorder=1000)

    plt.axis('off')
    plt.show()

In [None]:
# Record the installed torch version in the TORCH environment variable and
# print it.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

In [None]:
from torch_geometric.datasets import GeometricShapes
from torch_geometric.transforms import SamplePoints

In [None]:
# Instantiate the GeometricShapes dataset with ./GeometricShapes as its root directory.
dataset = GeometricShapes(root='GeometricShapes')

In [None]:
# Show how many classes the dataset reports.
dataset.num_classes

In [None]:
# Attach a SamplePoints transform configured with num=128 to the dataset.
dataset.transform = SamplePoints(num=128)

In [None]:
# Take the first sample of the dataset.
data = dataset[0]

In [None]:
# Print the sample's edge_index attribute.
# NOTE(review): presumably None until a graph is built for the sample — confirm.
print(data.edge_index)

In [None]:
# Print the sample's summary representation.
print(data)

In [None]:
from helpers.visualization import Visualization

In [None]:
# Show the Python type of the sample object.
type(data)

In [None]:
# Convert the positions to a NumPy array before building the frame: on older
# pandas releases a torch.Tensor handed straight to DataFrame is iterated
# element-wise, which leaves 0-dim tensor objects in the cells instead of floats.
df = pd.DataFrame(data.pos.detach().cpu().numpy(), columns=['x','y','z'])

In [None]:
# Build a 3D scatter of the coordinate frame with the project's Visualization
# helper and show it.
# NOTE(review): `viz` is bound to whatever `.show()` returns, which may well be
# None rather than the figure — confirm against helpers.visualization.
viz = Visualization(df).get_3d_scatter().show()

In [None]:
from torch_geometric.transforms import SamplePoints

In [None]:
from torch_cluster import knn_graph

In [None]:
# Build a k-nearest-neighbour graph (k=6) over the sample's positions and
# store the resulting edges on the sample.
data.edge_index = knn_graph(data.pos, k = 6)

In [None]:
# Plot the sample's points together with the edges stored on it.
visualize_points(data.pos, edge_index=data.edge_index)

In [None]:
from helpers.readers import Reader
import os
import glob
# Point the project's Reader helper at the local models directory.
# NOTE(review): the meaning of the '0' argument to get_folders is not visible
# from this notebook — confirm against helpers.readers.
reader = Reader('/home/dim26fa/data/imod_models/')
reader.get_folders('0')


In [None]:
# NOTE(review): presumably gathers the files of the folder at index 0 —
# confirm against helpers.readers.
reader.get_files_from_folder(0)
# Display the reader's `folder` attribute as the cell output.
reader.folder

In [None]:
# Accumulator for the localisation tables read in the next cell. A plain list
# is used because the loader calls `files.append(...)`, which NumPy arrays do
# not provide (and a pre-sized array of 10 would not fit an unknown file count).
files = []

In [None]:
# Read every localisation table of model_0 into a list of DataFrames.
# The list is (re)built here rather than appended to, because NumPy arrays
# have no `append` method. Paths are sorted so the order is reproducible.
# NOTE: without recursive=True, '**' behaves like '*', i.e. it matches exactly
# one directory level below model_0.
files = [
    pd.read_csv(file)
    for file in sorted(glob.glob('/home/dim26fa/data/imod_models/model_0/**/*Localizations.txt'))
]

In [None]:
# Turn the first 'Localizations.txt' file of every sample folder of model_1
# into a coordinate tensor.
# NOTE(review): the exact semantics of the Reader calls (e.g. the 4 passed to
# read_txt) are not visible from this notebook — confirm against helpers.readers.
path = '/home/dim26fa/data/imod_models/model_1'
tensors = []  # one tensor per sample folder; previously only the last one survived the loop
for folder in os.listdir(path):
    fpath = os.path.join(path, folder)
    if not os.path.isdir(fpath):
        continue  # skip stray files next to the sample folders
    reader = Reader(fpath)
    reader.get_files_from_folder(path = fpath)
    reader.filter('Localizations.txt')
    reader.set_file(0)
    reader.read_txt(4)
    f = reader.extract_xyz(column_names=[0,1,2]).to_numpy()
    tensor = torch.tensor(f)
    tensors.append(tensor)
        

In [None]:
# Print a marker for every sub-directory of model_1.
model_dir = '/home/dim26fa/data/imod_models/model_1'
for folder in os.listdir(model_dir):
    # os.listdir yields bare entry names; join them with the parent directory,
    # otherwise isdir() is resolved against the current working directory.
    if os.path.isdir(os.path.join(model_dir, folder)):
        print('0')

In [None]:
# Load one localisation file (skipping its first line) and split the single
# text column on spaces into at most five columns.
data = pd.read_csv('/home/dim26fa/data/imod_models/model_1/sample_0/adfl-r-cutLocalizations.txt', skiprows=[0])
# `n` must be passed by keyword: pandas 2.x made every argument of
# Series.str.split after `pat` keyword-only.
data = data[data.columns[0]].str.split(' ', n=4, expand=True)
data

### PointNet 

In [None]:
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import MessagePassing

In [None]:
class PointNetLayer(MessagePassing):
    """Message-passing layer that max-aggregates MLP-transformed neighbour messages.

    Each message is built from the neighbour's features concatenated with the
    positional offset between neighbour and centre node, then passed through
    a two-layer MLP.

    Args:
        in_channels: Size of the incoming per-node feature vectors.
        out_channels: Size of the produced per-node feature vectors.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='max')
        # The first layer takes `in_channels + 3` inputs: the node features
        # plus the relative position (assumes 3-D coordinates).
        layers = [
            Linear(in_channels + 3, out_channels),
            ReLU(),
            Linear(out_channels, out_channels),
        ]
        self.mlp = Sequential(*layers)

    def forward(self, h, pos, edge_index):
        """Propagate features `h` and positions `pos` along `edge_index`."""
        return self.propagate(edge_index, h=h, pos=pos)

    def message(self, h_j, pos_j, pos_i):
        """Compute one message per edge.

        Following the MessagePassing naming convention, `_j` arguments refer
        to the neighbouring node of an edge and `_i` to the centre node.
        """
        offset = pos_j - pos_i
        if h_j is None:
            # No node features available: the message depends on geometry only.
            return self.mlp(offset)
        return self.mlp(torch.cat([h_j, offset], dim=-1))
            
            
            

In [None]:
import torch.nn.functional as F
from torch_cluster import knn_graph
from torch_geometric.nn import global_max_pool

class PointNet(torch.nn.Module):
    """Two-layer PointNet-style classifier for batched point clouds.

    A k-nearest-neighbour graph is rebuilt from the positions on every
    forward pass, two `PointNetLayer`s are applied on it, the node features
    are max-pooled per graph and fed to a linear classifier.

    Args:
        num_classes: Number of output classes. When omitted, falls back to
            the notebook-level ``dataset.num_classes`` (the original
            behaviour), so ``PointNet()`` keeps working unchanged.
        k: Number of neighbours used for the k-NN graph (default 16).
    """

    def __init__(self, num_classes=None, k=16):
        super().__init__()

        if num_classes is None:
            # Backward-compatible default: read the class count from the
            # dataset defined earlier in the notebook.
            num_classes = dataset.num_classes
        self.k = k

        # Fixed seed so that the weight initialisation is reproducible.
        torch.manual_seed(12345)
        self.conv1 = PointNetLayer(3, 32)
        self.conv2 = PointNetLayer(32, 32)
        self.classifier = Linear(32, num_classes)

    def forward(self, pos, batch):
        """Return one row of class logits per graph in the batch.

        Args:
            pos: Point coordinates of all graphs in the batch, stacked.
            batch: For every point, the index of the graph it belongs to.
        """
        # loop=True also connects every point to itself.
        edge_index = knn_graph(pos, k=self.k, batch=batch, loop=True)

        # First layer: the positions double as the input node features.
        h = self.conv1(h=pos, pos=pos, edge_index=edge_index)
        h = h.relu()
        h = self.conv2(h=h, pos=pos, edge_index=edge_index)
        h = h.relu()

        # Collapse the per-point features into one vector per graph.
        h = global_max_pool(h, batch)

        return self.classifier(h)

# Instantiate the classifier and print its module structure.
model = PointNet()
print(model)

In [None]:
from IPython.display import Javascript
# Colab-specific: asks the Colab front end to limit the height of this cell's
# output frame (maxHeight: 300).
# NOTE(review): outside Google Colab the `google.colab` JavaScript object
# presumably does not exist, so this call has no useful effect there.
display(Javascript('''google.colab.output.setIframeHeight(0,true, {maxHeight: 300})'''))

from torch_geometric.loader import DataLoader

# Train / test splits of GeometricShapes, each sampled with SamplePoints(128).
train_dataset = GeometricShapes(
    root='data/GeometricShapes',
    train=True,
    transform=SamplePoints(128),
)
test_dataset = GeometricShapes(
    root='data/GeometricShapes',
    train=False,
    transform=SamplePoints(128),
)

# Mini-batches of 10 graphs; only the training data is shuffled.
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10)

# Fresh model with its optimiser and loss function.
model = PointNet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()


In [None]:
??GeometricShapes

In [None]:
class SMLMShapes(InMemoryDataset):
    r"""In-memory point-cloud dataset built from localisation text files.

    The raw data is expected under ``<root>/raw`` in the layout that
    :meth:`process_set` globs for::

        <root>/raw/<category>/train/*.txt
        <root>/raw/<category>/test/*.txt

    Every sub-folder of the raw directory is one class; its label is the
    folder's rank in the alphabetically sorted folder names. Each text file
    becomes one :obj:`torch_geometric.data.Data` object holding the centred
    point coordinates in :obj:`pos` and the class label in :obj:`y`.

    Args:
        root (string): Root directory where the dataset should be saved.
        path (string): Accepted for backward compatibility only; it is not
            used, the processed files below :obj:`root` are loaded instead.
            NOTE(review): confirm whether this was meant to point at the
            raw localisation data.
        train (bool, optional): If :obj:`True`, loads the training dataset,
            otherwise the test dataset. (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    # NOTE(review): file-format assumptions — one header line, whitespace
    # separated values, x/y/z in the first three columns. Confirm against the
    # real localisation files and adjust (or override `read_points`).
    header_rows = 1
    xyz_columns = (0, 1, 2)

    def __init__(self, root: str, path: str,
                 train: bool = True,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None):
        # The base-class constructor runs `process()` when the processed
        # files are missing.
        super().__init__(root, transform, pre_transform, pre_filter)
        path = self.processed_paths[0] if train else self.processed_paths[1]
        # The processed file stores the (data, slices) pair produced by collate().
        self.data, self.slices = torch.load(path)

    @property
    def raw_file_names(self) -> List[str]:
        # Raw files are discovered by globbing in `process_set`; there is no
        # fixed file list and nothing to download.
        return []

    @property
    def processed_file_names(self) -> List[str]:
        # Index 0 is the training split, index 1 the test split.
        return ['training.pt', 'test.pt']

    def process(self):
        """Build both splits from the raw files and save them to disk."""
        torch.save(self.process_set('train'), self.processed_paths[0])
        torch.save(self.process_set('test'), self.processed_paths[1])

    @classmethod
    def read_points(cls, path: str) -> Data:
        """Read one localisation text file into a ``Data`` object with ``pos``."""
        # ndmin=2 keeps a (1, 3) shape for files that hold a single point.
        xyz = np.loadtxt(path, skiprows=cls.header_rows,
                         usecols=cls.xyz_columns, ndmin=2)
        return Data(pos=torch.from_numpy(xyz).to(torch.float))

    def process_set(self, dataset: str):
        """Collate all samples of one split (``'train'`` or ``'test'``).

        Raises:
            FileNotFoundError: If no ``*.txt`` file is found for the split.
        """
        # The trailing separator makes the pattern match directories only.
        categories = glob.glob(osp.join(self.raw_dir, '*', ''))
        categories = sorted([x.split(os.sep)[-2] for x in categories])

        data_list = []
        for target, category in enumerate(categories):
            folder = osp.join(self.raw_dir, category, dataset)
            # Sorted so the sample order is reproducible across runs.
            for path in sorted(glob.glob(f'{folder}/*.txt')):
                data = self.read_points(path)
                # Centre every point cloud at the origin.
                data.pos = data.pos - data.pos.mean(dim=0, keepdim=True)
                data.y = torch.tensor([target])
                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                data_list.append(data)

        if not data_list:
            raise FileNotFoundError(
                f"No '*.txt' samples found for split '{dataset}' under {self.raw_dir}")

        return self.collate(data_list)

In [None]:
import glob
# Every entry directly below the models directory is treated as one category;
# its label is its rank in the alphabetically sorted names.
categories = sorted(
    entry.split(os.sep)[-1]
    for entry in glob.glob('/home/dim26fa/data/imod_models/*')
)
for target, category in enumerate(categories):
    print(target, category)

In [None]:
# Show the last path component of the second category name
# (raises IndexError when fewer than two categories were found).
categories[1].split(os.sep)[-1]

In [None]:

def train(model, optimizer, loader, loss_fn=None):
    """Run one training epoch and return the mean loss per graph.

    Args:
        model: Module called as ``model(data.pos, data.batch)``.
        optimizer: Optimiser that updates the model's parameters.
        loader: Iterable of batches exposing ``pos``, ``batch``, ``y`` and
            ``num_graphs``; must also expose ``dataset`` with a length.
        loss_fn: Loss called as ``loss_fn(logits, data.y)``. Defaults to the
            notebook-level ``criterion``.

    Returns:
        The summed per-graph loss divided by ``len(loader.dataset)``.
    """
    if loss_fn is None:
        # Backward-compatible default: the loss defined earlier in the notebook.
        loss_fn = criterion

    model.train()

    total_loss = 0
    for data in loader:
        optimizer.zero_grad()
        logits = model(data.pos, data.batch)
        loss = loss_fn(logits, data.y)
        loss.backward()
        optimizer.step()
        # The loss is a batch mean; weight it by the number of graphs in the batch.
        total_loss += loss.item() * data.num_graphs

    # Normalise by the dataset actually iterated, not by a global loader.
    return total_loss / len(loader.dataset)

@torch.no_grad()
def test(model, loader):
    model.eval()
    
    total_correct = 0 
    for data in loader:
        logits = model(data.pos, data.batch)
        pred = logits.argmax(dim=-1)
        total_correct += int((pred == data.y).sum())
        
    return total_correct / len(loader.dataset)
# Train for 100 epochs; after each epoch evaluate on the test loader, print
# the epoch's loss and test accuracy, and record the accuracy.
accuracies = []
for epoch in range(1,101):
    loss = train(model, optimizer, train_loader)
    test_acc = test(model, test_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Test accuracy: {test_acc: .4f}')
    accuracies.append(test_acc)

In [None]:
# Plot the recorded test accuracy per epoch. The x-axis is derived from the
# number of recorded values so the plot stays valid if the epoch count changes.
x = range(1, len(accuracies) + 1)
plt.plot(x, accuracies)
plt.xlabel('Epoch')
plt.ylabel('Test accuracy');