In [1]:
import os
import numpy as np
import h5py
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
import torch.nn as nn
import torch.optim as optim
from se3_transformer_pytorch import SE3Transformer


  Jd = torch.load(str(path))


# Custom HDF5 Dataset Class

In [2]:
class HDF5Dataset(Dataset):
    def __init__(self, h5_file, transform=None):
        self.h5_file = h5_file
        self.transform = transform
        with h5py.File(self.h5_file, 'r') as f:
            self.coords = f['coords'][:]
            self.features = f['features'][:]
            self.labels = f['labels'][:]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        coords = self.coords[idx]
        features = self.features[idx]
        label = self.labels[idx]
        if self.transform:
            coords, features = self.transform(coords, features)
        return torch.tensor(coords, dtype=torch.float32), torch.tensor(features, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

# Example Transform Function

In [3]:
def preprocess_data(coords, features):
    # Ensure standard deviation is not zero to avoid division by zero
    if np.std(features) != 0:
        features = (features - np.mean(features)) / np.std(features)
    else:
        features = features - np.mean(features)
    
    # Normalize coordinates
    coords = coords - coords.min(axis=0)  # Shift coordinates to start from 0
    coords = coords / coords.max(axis=0)  # Normalize to range [0, 1]
    
    return coords, features

# Model Definition

In [4]:
class Image3DClassifier(nn.Module):
    def __init__(self, num_classes):
        super(Image3DClassifier, self).__init__()
        self.se3_transformer = SE3Transformer(
            num_tokens=1,  # Assuming 3D image voxels are treated as tokens
            dim=64,  # Changed from 128 to 64
            depth=6,
            heads=8,
            dim_head=64,
            num_degrees=4,
            input_degrees=1,  # Assuming input has one degree (scalar features)
            output_degrees=1  # Assuming output has one degree (scalar features)
        )
        self.classifier = nn.Linear(64, num_classes)  # Changed from 128 to 64

    def forward(self, x):
        coords, features = x
        coords = coords.long()  # Convert to LongTensor
        print("Coordinates min:", coords.min().item(), "max:", coords.max().item())  # Debugging statement
        x = self.se3_transformer(coords, features)
        x = x.mean(dim=1)  # Global average pooling
        x = self.classifier(x)
        return x

model = Image3DClassifier(num_classes=3)  # Example for a dataset with 3 classes

# Training Script

In [5]:
# Set up data directory path and HDF5 file
h5_file = 'prepared_dataset_1024.h5'  # Path to your HDF5 dataset

# Create dataset and dataloader
train_dataset = HDF5Dataset(h5_file, transform=preprocess_data)
# Assuming train_dataset is already defined
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# Model, loss function, optimizer
model = Image3DClassifier(num_classes=3)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(10):  # Number of epochs
    for i, (coords, features, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        coords = coords.long()  # Convert to LongTensor
        print("Coordinates min:", coords.min().item(), "max:", coords.max().item())  # Debugging statement
        outputs = model((coords, features))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/10], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')

# Save the model
torch.save(model.state_dict(), 'model.pth')

Coordinates min: 0 max: 1
Coordinates min: 0 max: 1


IndexError: index out of range in self

# Evaluation Script

In [None]:
# Create dataset and dataloader
test_dataset = HDF5Dataset(h5_file, transform=preprocess_data)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

# Model
model = Image3DClassifier(num_classes=3)
model.load_state_dict(torch.load('model.pth'))  # Load trained model

# Evaluation
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for coords, features, labels in test_loader:
        outputs = model((coords, features))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total:.2f}%')
