In [1]:
import traceback
from itertools import islice

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from skimage.segmentation import slic
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from torch import optim
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv
from torchvision.models import convnext_base



In [2]:
# Data Generator Setup
# Dataset location and loader settings live in one place so the train and
# validation loaders cannot drift apart; edit DATA_DIR to point at your copy.
DATA_DIR = '/Users/anantsinha/Downloads/Cotton Disease'
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32

# Training images: pixel values scaled by 1/255 plus random shear/zoom/flip.
train_datagenerator = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
)

# Validation images: only the 1/255 scaling, no augmentation.
test_datagenerator = ImageDataGenerator(rescale=1./255)

train_data = train_datagenerator.flow_from_directory(
    f'{DATA_DIR}/train',
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical'
)

test_data = test_datagenerator.flow_from_directory(
    f'{DATA_DIR}/val',
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical'
)



Found 1951 images belonging to 4 classes.
Found 253 images belonging to 4 classes.


In [3]:
def generate_superpixels(image, n_segments=100, compactness=10, sigma=1):
    """Segment an image into SLIC superpixels.

    Args:
        image: Image as a ``torch.Tensor`` or NumPy array. A 3-D input whose
            first axis has length 3 is treated as channel-first and is moved
            to channel-last before segmentation. Inputs whose maximum exceeds
            1.0 are divided by 255.
        n_segments: Forwarded to ``slic`` (default 100, the value that was
            previously hard-coded).
        compactness: Forwarded to ``slic`` (default 10).
        sigma: Forwarded to ``slic`` (default 1).

    Returns:
        The label array returned by ``skimage.segmentation.slic``.
    """
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()

    # Channel-first (3, H, W) -> channel-last (H, W, 3).
    if len(image.shape) == 3 and image.shape[0] == 3:
        image = np.transpose(image, (1, 2, 0))

    # Values above 1.0 are taken to mean a 0-255 scale and are brought to 0-1.
    if image.max() > 1.0:
        image = image / 255.0

    return slic(image, n_segments=n_segments, compactness=compactness, sigma=sigma)

def calculate_features(image, segments):
    """Compute the mean value of the image over every superpixel.

    Args:
        image: Image as a ``torch.Tensor`` or NumPy array. A 3-D input whose
            first axis has length 3 is treated as channel-first and is moved
            to channel-last first.
        segments: Integer label array with the same spatial shape as the
            image. Labels may start at any value and need not be contiguous.

    Returns:
        NumPy array with one row per distinct label in ``segments``, ordered
        by ascending label; each row is the mean of the pixels carrying that
        label. An empty ``segments`` yields an array of shape ``(0, 3)``.
    """
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()

    # Channel-first (3, H, W) -> channel-last (H, W, 3).
    if len(image.shape) == 3 and image.shape[0] == 3:
        image = np.transpose(image, (1, 2, 0))

    segments = np.asarray(segments)
    labels = np.unique(segments)
    if labels.size == 0:
        return np.zeros((0, 3))

    # Iterate over the labels that actually occur instead of range(K): with
    # 1-based or gapped labels, range(K) yields a spurious zero row for label
    # 0 and never visits the largest label.
    return np.array([np.mean(image[segments == label], axis=0) for label in labels])

def create_graph_from_superpixels_pyg(image, segments, max_segments=100):
    """Build a PyTorch Geometric graph whose nodes are superpixels.

    Args:
        image: Image as a ``torch.Tensor`` or NumPy array. A 3-D input whose
            first axis has length 3 is treated as channel-first and is moved
            to channel-last first.
        segments: 2-D integer label array with the image's spatial shape.
            Labels may start at any value and need not be contiguous.
        max_segments: Number of nodes in the returned graph. Node rows beyond
            the number of superpixels present stay all-zero.

    Returns:
        ``Data`` with ``x`` of shape ``(max_segments, 3)`` holding the mean
        pixel value of each superpixel, and ``edge_index`` holding both
        directions of every pair of superpixels that touch horizontally or
        vertically. If no two superpixels touch, ``edge_index`` is one
        self-loop per node.
    """
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
    # Channel-first (3, H, W) -> channel-last (H, W, 3).
    if len(image.shape) == 3 and image.shape[0] == 3:
        image = np.transpose(image, (1, 2, 0))

    # Relabel to contiguous ids 0..K-1 so every id is a valid row of ``x``.
    # Using raw labels breaks for 1-based labellings: a label equal to
    # max_segments would get no feature row yet still appear in edge_index.
    segments = np.asarray(segments)
    unique_labels, inverse = np.unique(segments, return_inverse=True)
    node_ids = inverse.reshape(segments.shape)
    num_nodes_used = min(len(unique_labels), max_segments)
    if len(unique_labels) > max_segments:
        # Too many superpixels for the fixed node budget: fold the surplus
        # ids back onto existing nodes (this merges unrelated regions).
        node_ids = node_ids % max_segments

    # Mean pixel value per node; ids 0..num_nodes_used-1 are all occupied.
    features = np.zeros((max_segments, 3))
    for node in range(num_nodes_used):
        features[node] = np.mean(image[node_ids == node], axis=0)

    # Adjacent pixels with different ids mark a boundary between two nodes.
    right = node_ids[:, :-1] != node_ids[:, 1:]
    down = node_ids[:-1, :] != node_ids[1:, :]
    src = np.concatenate([node_ids[:, :-1][right], node_ids[:-1, :][down]])
    dst = np.concatenate([node_ids[:, 1:][right], node_ids[1:, :][down]])

    if src.size == 0:
        # No boundaries at all: fall back to one self-loop per node.
        edge_index = torch.arange(max_segments).repeat(2, 1)
    else:
        # Store each boundary in both directions, without duplicates.
        pairs = np.stack([np.concatenate([src, dst]), np.concatenate([dst, src])])
        pairs = np.unique(pairs, axis=1)
        edge_index = torch.from_numpy(pairs).long().contiguous()

    node_features = torch.tensor(features, dtype=torch.float)
    return Data(x=node_features, edge_index=edge_index)



In [4]:
class ConvNeXtGCNModel(nn.Module):
    """Classifier that feeds ConvNeXt image features into a two-layer GCN.

    Each image's backbone feature vector is reduced to ``hidden_dim``, copied
    onto every one of the ``max_segments`` graph nodes, concatenated with the
    3 per-node graph features, passed through two ``GCNConv`` layers, averaged
    over the nodes and classified.
    """

    def __init__(self, num_classes=4, hidden_dim=128, max_segments=100):
        super().__init__()
        self.max_segments = max_segments

        # Backbone: pretrained ConvNeXt-Base with its last child module dropped.
        backbone = convnext_base(pretrained=True)
        trunk = list(backbone.children())[:-1]
        self.convnext = nn.Sequential(*trunk)

        # Projects the flattened 1024-wide backbone output down to hidden_dim.
        self.dim_reduction = nn.Linear(1024, hidden_dim)

        # Graph layers; the first takes projected image features + 3 node features.
        self.gcn1 = GCNConv(hidden_dim + 3, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)

        # Classification head applied to the pooled node embedding.
        head_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        self.dropout = nn.Dropout(0.5)

    def forward(self, images, graph_data):
        """Return class logits for a batch of images and their batched graph.

        ``graph_data.x`` is concatenated row-wise with one copy of the image
        feature per node, so it must hold ``batch * max_segments`` rows.
        """
        n_images = images.size(0)

        # Batches whose second axis is not 3 are taken to be channel-last.
        if images.size(1) != 3:
            images = images.permute(0, 3, 1, 2)

        # Backbone -> flatten -> project to hidden_dim.
        image_feat = self.convnext(images).view(n_images, -1)
        image_feat = self.dim_reduction(image_feat)

        # One copy of each image's feature vector per graph node.
        per_node = image_feat.unsqueeze(1).expand(-1, self.max_segments, -1)
        per_node = per_node.reshape(-1, per_node.size(-1))

        node_inputs = torch.cat([graph_data.x, per_node], dim=1)
        edges = graph_data.edge_index

        # Two graph convolutions, dropout only after the first.
        hidden = self.dropout(F.relu(self.gcn1(node_inputs, edges)))
        hidden = F.relu(self.gcn2(hidden, edges))

        # Average node embeddings per image, then classify.
        pooled = hidden.view(n_images, self.max_segments, -1).mean(dim=1)
        return self.classifier(pooled)



In [12]:
def train_model(model, train_data, optimizer, criterion, device, num_epochs=10):
    """Train ``model`` on batches of images with per-image superpixel graphs.

    Args:
        model: Module called as ``model(images, batch_graph)``.
        train_data: Iterable yielding ``(images, labels)`` pairs of NumPy
            arrays. Labels may be class indices or one-hot rows.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, class_indices)``.
        device: Device the model and batches are moved to.
        num_epochs: Number of passes over ``train_data``.

    Progress (mean loss and accuracy over the last 10 batches, plus the last
    batch's predictions and labels) is printed every 10 batches. A batch that
    raises is reported with its traceback and skipped.
    """
    model = model.to(device)
    model.train()

    # Keras directory iterators are documented to yield batches indefinitely,
    # so a bare ``for`` over one never finishes an epoch. When the loader
    # reports a length, cap each epoch at that many batches.
    try:
        steps_per_epoch = len(train_data)
    except TypeError:
        steps_per_epoch = None  # unsized iterable: consume until exhausted

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        epoch_batches = islice(train_data, steps_per_epoch)
        for batch_idx, (images, labels) in enumerate(epoch_batches):
            try:
                # Convert numpy arrays to tensors and handle data types
                images = torch.from_numpy(images).float().to(device)

                # Convert one-hot encoded labels to class indices
                if len(labels.shape) > 1:  # If labels are one-hot encoded
                    labels = np.argmax(labels, axis=1)
                labels = torch.from_numpy(labels).long().to(device)

                # Build one superpixel graph per image (done on the CPU).
                graphs = []
                for img in images:
                    img_cpu = img.cpu()
                    segments = generate_superpixels(img_cpu)
                    graphs.append(create_graph_from_superpixels_pyg(img_cpu, segments))

                batch_graph = Batch.from_data_list(graphs).to(device)

                # Forward pass
                optimizer.zero_grad()
                outputs = model(images, batch_graph)

                # CrossEntropyLoss expects class indices, not one-hot encoded labels
                loss = criterion(outputs, labels)

                # Backward pass
                loss.backward()
                optimizer.step()

                # Statistics
                running_loss += loss.item()
                predicted = outputs.detach().argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                if batch_idx % 10 == 9:
                    accuracy = 100 * correct / total
                    print(f'Epoch: {epoch+1}, Batch: {batch_idx+1}, '
                          f'Loss: {running_loss/10:.4f}, Accuracy: {accuracy:.2f}%')
                    print(f'Predictions: {predicted}')
                    print(f'True labels: {labels}')
                    running_loss = 0.0
                    correct = 0
                    total = 0

            except Exception as e:
                # Deliberate best-effort: report the failing batch and move on.
                print(f"Error in batch {batch_idx}: {str(e)}")
                print(traceback.format_exc())
                continue

In [13]:
# Choose the compute device in this cell so it does not rely on a variable
# left over from another cell (fails with NameError on a fresh kernel otherwise).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

criterion = nn.CrossEntropyLoss()

# Initialize model with correct number of classes
num_classes = len(train_data.class_indices)
print(f"Number of classes: {num_classes}")

model = ConvNeXtGCNModel(num_classes=num_classes, hidden_dim=128, max_segments=100)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

# Train the model
train_model(model, train_data, optimizer, criterion, device)

Number of classes: 4
Epoch: 1, Batch: 10, Loss: 1.3791, Accuracy: 35.31%
Predictions: tensor([1, 3, 3, 1, 1, 1, 1, 1, 1, 1, 3, 3, 1, 0, 1, 1, 1, 3, 3, 1, 3, 1, 3, 3,
        3, 3, 3, 3, 1, 3, 1, 3])
True labels: tensor([0, 2, 3, 2, 1, 1, 1, 1, 1, 1, 2, 2, 1, 0, 0, 2, 1, 2, 1, 3, 2, 1, 1, 1,
        2, 1, 2, 1, 2, 1, 2, 1])
Epoch: 1, Batch: 20, Loss: 1.2734, Accuracy: 39.69%
Predictions: tensor([1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        2, 1, 1, 1, 1, 1, 1, 1])
True labels: tensor([0, 3, 0, 1, 0, 3, 3, 1, 1, 3, 1, 1, 1, 3, 3, 3, 1, 2, 1, 0, 1, 2, 3, 0,
        0, 1, 1, 1, 1, 1, 3, 1])
Epoch: 1, Batch: 30, Loss: 1.1561, Accuracy: 45.00%
Predictions: tensor([1, 2, 0, 2, 0, 1, 0, 2, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 2,
        1, 2, 1, 1, 0, 0, 1, 3])
True labels: tensor([1, 2, 2, 2, 2, 1, 3, 2, 1, 1, 1, 1, 1, 3, 2, 2, 2, 2, 1, 3, 1, 1, 1, 2,
        1, 2, 3, 1, 2, 0, 3, 1])
Epoch: 1, Batch: 40, Loss: 1.1479, Accuracy: 47.81%
Predictions: te

KeyboardInterrupt: 

Error in batch 0: Expected floating point type for target with class probabilities, got Long
Error in batch 1: Expected floating point type for target with class probabilities, got Long


KeyboardInterrupt: 