# Rice Diseases: Image-to-Graph Conversion

This notebook converts rice disease images to graph structures using superpixel segmentation.
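The conversion idea can be shown without any image libraries. A superpixel algorithm (SLIC, for example) assigns a segment id to every pixel. Each segment then becomes a node, and every pair of touching segments becomes an edge.

The sketch below is a minimal pure-Python illustration under stated assumptions. It is not the notebook's `ImageToGraphConverter`. The function name `label_map_to_graph`, the 4-connected adjacency rule, and the per-channel absolute colour difference used as the edge feature are all assumptions. They were chosen to match the "RGB color" node features and "color difference" edge features reported in Section 6.

```python
from collections import defaultdict


def label_map_to_graph(labels, image):
    """Build a region adjacency graph from a superpixel label map (hypothetical helper).

    labels: H x W nested lists of integer segment ids
    image:  H x W nested lists of (r, g, b) tuples
    Returns (node_features, edge_index, edge_attr) as plain lists.
    """
    height, width = len(labels), len(labels[0])

    # Accumulate per-segment colour sums and pixel counts
    sums = defaultdict(lambda: [0.0, 0.0, 0.0])
    counts = defaultdict(int)
    for y in range(height):
        for x in range(width):
            seg = labels[y][x]
            for c in range(3):
                sums[seg][c] += image[y][x][c]
            counts[seg] += 1

    # Renumber segments 0..K-1 and use the mean RGB colour as the node feature
    seg_ids = sorted(counts)
    index = {seg: i for i, seg in enumerate(seg_ids)}
    node_features = [[s / counts[seg] for s in sums[seg]] for seg in seg_ids]

    # Two segments are adjacent if some pair of 4-connected pixels straddles them;
    # checking only the down and right neighbours visits every such pair once
    pairs = set()
    for y in range(height):
        for x in range(width):
            a = index[labels[y][x]]
            for ny, nx in ((y + 1, x), (y, x + 1)):
                if ny < height and nx < width:
                    b = index[labels[ny][nx]]
                    if a != b:
                        pairs.add((min(a, b), max(a, b)))

    # Store each undirected edge in both directions, as PyG's edge_index expects
    edge_index = [[], []]
    edge_attr = []
    for i, j in sorted(pairs):
        # Assumed edge feature: per-channel absolute difference of mean colours
        diff = [abs(p - q) for p, q in zip(node_features[i], node_features[j])]
        for src, dst in ((i, j), (j, i)):
            edge_index[0].append(src)
            edge_index[1].append(dst)
            edge_attr.append(diff)
    return node_features, edge_index, edge_attr


# Tiny 3x3 example with three segments: red (0), green (1), blue (2)
labels = [[0, 0, 1],
          [0, 0, 1],
          [2, 2, 1]]
red, green, blue = (255, 0, 0), (0, 255, 0), (0, 0, 255)
image = [[red, red, green],
         [red, red, green],
         [blue, blue, green]]
x, edge_index, edge_attr = label_map_to_graph(labels, image)
print(len(x), len(edge_index[0]))  # 3 nodes, 6 directed edges
```

Each undirected edge is stored twice, once per direction. This follows PyTorch Geometric's convention, in which `edge_index` represents an undirected graph as pairs of directed edges.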

**Output**: Individual .pt files for each graph + zip archive (fairseq-compatible)

**Dataset**: 4 classes (3 diseases plus healthy leaves): BrownSpot, Healthy, Hispa, LeafBlast

## 1. Setup Environment

Mount Google Drive, then clone the Graphormer repository and switch into it.

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

print("✓ Google Drive mounted successfully")

In [None]:
# Clone Graphormer repository (if not already cloned)
import os

if not os.path.exists('/content/Graphormer'):
    !git clone https://github.com/microsoft/Graphormer.git
    print("✓ Cloned Graphormer repository")
else:
    print("✓ Graphormer repository already exists")

%cd /content/Graphormer

## 2. Install Dependencies

Install PyTorch Geometric and the rice_diseases-specific dependencies.

In [None]:
# Install PyTorch Geometric and dependencies
# Get PyTorch version first
import torch
torch_version = torch.__version__.split('+')[0]
cuda_version = torch.version.cuda

print(f"PyTorch version: {torch_version}")
print(f"CUDA version: {cuda_version}")
print("\nInstalling PyTorch Geometric...")

# Install PyG 1.7.2. This is an older release: confirm it is compatible with the PyTorch version printed above
!pip install -q torch-geometric==1.7.2

print("\n" + "=" * 60)
print("✓ PyTorch Geometric installed")
print("=" * 60)

In [None]:
# Install rice_diseases specific dependencies
!pip install -q -r examples/rice_diseases/requirements.txt

print("\n" + "=" * 60)
print("✓ All dependencies installed")
print("=" * 60)

In [None]:
# Verify installation
import sys
sys.path.append('/content/Graphormer')

from examples.rice_diseases.colab_setup import verify_installation

verify_installation()

## 3. Copy and Extract Dataset

Copy dataset from Google Drive to local Colab storage and extract.

**Note**: Copying the zip to `/tmp` before extracting it avoids RAM overflow.
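The copy-then-extract pattern can be sketched with the standard library. `copy_then_extract` is a hypothetical name, and this is not the implementation of `copy_and_extract_dataset`; it only illustrates the idea. The point is that neither step loads the whole archive into memory. The copy is done by `shutil` in chunks (or by the OS), and `zipfile` decompresses each member as a stream.

```python
import shutil
import zipfile
from pathlib import Path


def copy_then_extract(src_zip, temp_dir, extract_dir):
    """Copy a zip to local disk, then extract it member by member (hypothetical helper)."""
    temp_dir, extract_dir = Path(temp_dir), Path(extract_dir)
    temp_dir.mkdir(parents=True, exist_ok=True)
    extract_dir.mkdir(parents=True, exist_ok=True)

    # Copy to fast local storage first; shutil copies without reading the file into RAM
    local_zip = temp_dir / Path(src_zip).name
    shutil.copy2(src_zip, local_zip)

    # Extract one member at a time; each member is streamed to disk
    with zipfile.ZipFile(local_zip) as zf:
        for member in zf.infolist():
            zf.extract(member, extract_dir)
    return extract_dir
```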

In [None]:
from examples.rice_diseases.colab_setup import copy_and_extract_dataset

# Copy from Drive and extract
data_dir = copy_and_extract_dataset(
    drive_zip_path="MyDrive/Rice_Diseases_Dataset/rice-diseases-image-dataset.zip",
    temp_dir="/tmp",
    extract_dir="/content/rice_diseases_data"
)

print(f"\n✓ Dataset extracted to: {data_dir}")

In [None]:
# Analyze dataset structure
from examples.rice_diseases.colab_setup import get_dataset_structure

dataset_structure = get_dataset_structure(data_dir)

## 4. Process Images to Individual .pt Files

Convert all images to graphs and save each as a separate .pt file.

**Images are processed one by one** to avoid RAM overflow: each graph is written to disk as soon as it is built, then freed from memory.
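The one-at-a-time pattern looks roughly like the sketch below. It is a minimal illustration under assumptions, not `RiceDiseasesDataset`'s code. `process_one_by_one` and `convert` are hypothetical names. `pickle` stands in for whatever serializer the dataset class actually uses (presumably `torch.save`). The `data_{i}.pkl` naming mirrors the `data_{i}.pt` files the notebook produces.

```python
import pickle
from pathlib import Path


def process_one_by_one(items, convert, out_dir):
    """Convert items lazily and write each result to its own file (hypothetical helper)."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    written = 0
    for i, item in enumerate(items):
        graph = convert(item)  # only one converted graph is held at a time
        with open(out_dir / f"data_{i}.pkl", "wb") as f:
            pickle.dump(graph, f)
        del graph  # drop the reference so it can be freed before the next item
        written += 1
    return written
```

Because `items` can be any iterable, including a generator over image paths, peak memory stays roughly at one image and one graph regardless of dataset size.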

In [None]:
from examples.rice_diseases.rice_diseases_dataset import RiceDiseasesDataset

# Create dataset and process all images
# This will create individual .pt files in rice_diseases_graphs/processed/
dataset = RiceDiseasesDataset(
    root="/content/rice_diseases_graphs",
    image_dir="/content/rice_diseases_data",
    split='train',  # Any split works here; processing converts every image
    n_segments=75,
    force_process=True  # Set to False if already processed
)

print("\n" + "=" * 60)
print("✓ Image processing complete!")
print("=" * 60)

In [None]:
# Check processed files
import os
from pathlib import Path

processed_dir = "/content/rice_diseases_graphs/processed"
pt_files = list(Path(processed_dir).glob("data_*.pt"))
metadata_file = Path(processed_dir) / "metadata.json"
split_file = Path(processed_dir) / "split_indices.pt"

print("Processed files:")
print("-" * 60)
print(f"Number of graph files (.pt): {len(pt_files)}")
print(f"Metadata file: {'✓' if metadata_file.exists() else '✗'}")
print(f"Split indices file: {'✓' if split_file.exists() else '✗'}")

# Calculate total size
total_size = sum(f.stat().st_size for f in pt_files)
total_size_mb = total_size / (1024 * 1024)
print(f"Total size: {total_size_mb:.2f} MB")
print("-" * 60)

## 5. Create Zip Archive

Package all .pt files into a single zip archive for easy distribution.
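Packaging a directory like this takes only a few lines of `zipfile`. The sketch below is a hypothetical stand-in for `create_dataset_zip`. It stores paths relative to the parent of `processed_dir`, so the archive unpacks to `processed/...`. Whether the real helper uses the same internal layout or compression is an assumption.

```python
import zipfile
from pathlib import Path


def zip_processed_dir(processed_dir, output_zip_path):
    """Pack every file under processed_dir into one zip archive (hypothetical helper)."""
    processed_dir = Path(processed_dir)
    with zipfile.ZipFile(output_zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for path in sorted(processed_dir.rglob("*")):
            if path.is_file():
                # Keep the 'processed/' prefix inside the archive
                zf.write(path, arcname=path.relative_to(processed_dir.parent).as_posix())
    return output_zip_path
```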

In [None]:
from examples.rice_diseases.rice_diseases_dataset import create_dataset_zip

# Create zip archive
zip_path = create_dataset_zip(
    processed_dir="/content/rice_diseases_graphs/processed",
    output_zip_path="/content/rice_diseases_graphs.zip"
)

print(f"\n✓ Zip archive ready: {zip_path}")
print("\nYou can download this file and share it!")

## 6. Load Dataset Splits

Verify that the train, val, and test splits load correctly.
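The split indices are stored in `split_indices.pt`. One common way to produce such indices is a seeded, stratified split, which keeps each class at roughly the same proportion in every split. The sketch below assumes that approach, with illustrative 80/10/10 ratios and seed 42. The notebook does not say how `RiceDiseasesDataset` actually builds its split, and `stratified_split` is a hypothetical name.

```python
import random
from collections import defaultdict


def stratified_split(labels, ratios=(0.8, 0.1, 0.1), seed=42):
    """Return sorted (train, val, test) index lists, split separately per class.

    The ratios and seed are illustrative assumptions only.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)  # seeded, so the split is reproducible
    train, val, test = [], [], []
    for label in sorted(by_class):
        idxs = by_class[label]
        rng.shuffle(idxs)
        n_train = int(len(idxs) * ratios[0])
        n_val = int(len(idxs) * ratios[1])
        train += idxs[:n_train]
        val += idxs[n_train:n_train + n_val]
        test += idxs[n_train + n_val:]  # any rounding remainder goes to test
    return sorted(train), sorted(val), sorted(test)
```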

In [None]:
from examples.rice_diseases.rice_diseases_dataset import RiceDiseasesDataset, CLASS_NAMES

# Load each split
train_dataset = RiceDiseasesDataset(root="/content/rice_diseases_graphs", split='train')
val_dataset = RiceDiseasesDataset(root="/content/rice_diseases_graphs", split='val')
test_dataset = RiceDiseasesDataset(root="/content/rice_diseases_graphs", split='test')

print("Dataset splits:")
print("-" * 60)
print(f"Train: {len(train_dataset)} samples")
print(f"Val:   {len(val_dataset)} samples")
print(f"Test:  {len(test_dataset)} samples")
print("-" * 60)

# Inspect a sample
sample = train_dataset[0]
print("\nSample graph structure:")
print(f"  Nodes: {sample.x.shape[0]}")
print(f"  Edges: {sample.edge_index.shape[1]}")
print(f"  Node features: {sample.x.shape} (RGB color)")
print(f"  Edge features: {sample.edge_attr.shape} (color difference)")
print(f"  Label: {sample.y.item()} ({CLASS_NAMES[sample.y.item()]})")

## 7. Generate Visualizations

Create sample visualizations showing the image-to-graph conversion process.

In [None]:
# Load metadata to get image paths
import json

with open("/content/rice_diseases_graphs/processed/metadata.json", 'r') as f:
    metadata = json.load(f)

print(f"Total processed images: {metadata['num_graphs']}")
print(f"Classes: {metadata['class_names']}")

In [None]:
# Create visualizations for a few samples
import os
from examples.rice_diseases.visualize_graphs import visualize_image_to_graph
from examples.rice_diseases.rice_image_to_graph import ImageToGraphConverter
from IPython.display import Image, display
import matplotlib.pyplot as plt

converter = ImageToGraphConverter(n_segments=75)
viz_dir = "/content/rice_diseases_visualizations"
os.makedirs(viz_dir, exist_ok=True)  # ensure the output directory exists before saving

# Visualize 2 samples per class
samples_per_class = 2
viz_files = []

for class_idx, class_name in enumerate(CLASS_NAMES):
    # Collect the image paths that belong to this class
    class_images = [
        path
        for path, label in zip(metadata['image_paths'], metadata['labels'])
        if label == class_idx
    ]

    for j, img_path in enumerate(class_images[:samples_per_class], start=1):
        save_path = f"{viz_dir}/{class_name}_sample_{j}.png"
        fig = visualize_image_to_graph(img_path, converter, save_path=save_path)
        viz_files.append(save_path)
        plt.close(fig)  # release the figure's memory

print(f"\n✓ Created {len(viz_files)} visualizations in {viz_dir}")

In [None]:
# Display some visualizations
print("Sample visualizations:")
print("=" * 60)

for viz_file in viz_files[:4]:
    if os.path.exists(viz_file):
        print(f"\n{os.path.basename(viz_file)}:")
        display(Image(filename=viz_file, width=800))

## 8. Summary

### What Was Created

1. **Individual .pt files**: Each graph saved separately in `/content/rice_diseases_graphs/processed/`
2. **Split indices**: Train/val/test splits in `split_indices.pt`
3. **Metadata**: Dataset information in `metadata.json`
4. **Zip archive**: All files packaged in `rice_diseases_graphs.zip`
5. **Visualizations**: Sample conversions in `/content/rice_diseases_visualizations/`

### Dataset Format (PyTorch Geometric)

```
rice_diseases_graphs/
  processed/
    data_0.pt
    data_1.pt
    ...
    split_indices.pt
    metadata.json
```
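A quick consistency check on this layout can catch partial or interrupted processing runs. The sketch below is hypothetical (`check_processed_layout` is not part of the notebook) and rests on two assumptions. First, `metadata.json` has a `num_graphs` key, as read in Section 7. Second, graph files are numbered `data_0.pt` to `data_{N-1}.pt` with no gaps.

```python
import json
from pathlib import Path


def check_processed_layout(processed_dir):
    """Return a list of layout problems (empty if the directory looks consistent)."""
    processed_dir = Path(processed_dir)
    problems = []
    for name in ("metadata.json", "split_indices.pt"):
        if not (processed_dir / name).exists():
            problems.append(f"missing {name}")

    # Extract N from each data_N.pt filename
    indices = sorted(
        int(p.stem.split("_", 1)[1]) for p in processed_dir.glob("data_*.pt")
    )
    if indices != list(range(len(indices))):
        problems.append("graph files are not numbered contiguously from 0")

    meta_path = processed_dir / "metadata.json"
    if meta_path.exists():
        num_graphs = json.loads(meta_path.read_text())["num_graphs"]
        if num_graphs != len(indices):
            problems.append(f"metadata says {num_graphs} graphs, found {len(indices)} files")
    return problems
```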

### Next Steps: Training with Graphormer

The dataset is now ready for training! Use the training script:

```bash
cd /content/Graphormer/examples/rice_diseases
bash rice_diseases.sh
```

Or run `fairseq-train` directly from the same directory (the `--user-dir` path is relative):

```bash
fairseq-train \
  --user-dir ../../graphormer \
  --dataset-name rice_diseases \
  --dataset-source pyg \
  --task graph_prediction \
  --criterion multiclass_cross_entropy \
  --num-classes 4 \
  --batch-size 32 \
  ...
```

In [None]:
print("=" * 80)
print("                    DATA CONVERSION COMPLETE!")
print("=" * 80)
print("\nProcessed graphs: /content/rice_diseases_graphs/processed/")
print("Zip archive: /content/rice_diseases_graphs.zip")
print("Visualizations: /content/rice_diseases_visualizations/")
print("\nDataset is ready for Graphormer training!")
print("Run: cd /content/Graphormer/examples/rice_diseases && bash rice_diseases.sh")
print("=" * 80)