In [2]:
# Step 1: Install required packages and mount Drive
!pip install roboflow torch torchvision matplotlib seaborn scikit-learn

import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets, models
from torch.utils.data import DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import shutil
import yaml
from google.colab import drive

# Mount Google Drive FIRST
drive.mount('/content/drive')

# Create a dedicated folder for your tongue disease project
project_folder = '/content/drive/MyDrive/Tongue_Disease_AI'
os.makedirs(project_folder, exist_ok=True)
print(f"✅ Project folder created at: {project_folder}")

# Set dataset paths in Drive
raw_dataset_path = os.path.join(project_folder, 'raw_dataset')
processed_dataset_path = os.path.join(project_folder, 'processed_dataset')
models_folder = os.path.join(project_folder, 'models')

# Create necessary folders
os.makedirs(raw_dataset_path, exist_ok=True)
os.makedirs(processed_dataset_path, exist_ok=True)
os.makedirs(models_folder, exist_ok=True)

print("✅ All Drive folders created!")


Collecting roboflow
  Downloading roboflow-1.2.3-py3-none-any.whl.metadata (9.7 kB)
Collecting idna==3.7 (from roboflow)
  Downloading idna-3.7-py3-none-any.whl.metadata (9.9 kB)
Collecting opencv-python-headless==4.10.0.84 (from roboflow)
  Downloading opencv_python_headless-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting pi-heif<2 (from roboflow)
  Downloading pi_heif-1.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.5 kB)
Collecting pillow-avif-plugin<2 (from roboflow)
  Downloading pillow_avif_plugin-1.5.2-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (2.1 kB)
Collecting python-dotenv (from roboflow)
  Downloading python_dotenv-1.1.1-py3-none-any.whl.metadata (24 kB)
Collecting filetype (from roboflow)
  Downloading filetype-1.2.0-py2.py3-none-any.whl.metadata (6.5 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.meta

In [3]:
# Step 2: Download Dataset with CORRECT Format
from roboflow import Roboflow

# Your API key
api_key = "8Busg0IT41XLrpXmDMDL"

print("🔐 Using your Roboflow API key...")
print("📦 Downloading tongue disease dataset to Google Drive...")

# Change to the Drive directory before downloading
os.chdir(raw_dataset_path)

# Initialize Roboflow
rf = Roboflow(api_key=api_key)

# Download with CORRECT format for multilabel-classification
project = rf.workspace("medical-wmypr").project("tongue-tod5c")
dataset = project.version(1).download("folder")  # Changed from "yolov8" to "folder"

print("✅ Dataset downloaded successfully to Google Drive!")

# The dataset will be at: /content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/tongue-tod5c-1/
drive_dataset_path = os.path.join(raw_dataset_path, "tongue-tod5c-1")

print(f"📂 Dataset location in Drive: {drive_dataset_path}")


🔐 Using your Roboflow API key...
📦 Downloading tongue disease dataset to Google Drive...
loading Roboflow workspace...
loading Roboflow project...
✅ Dataset downloaded successfully to Google Drive!
📂 Dataset location in Drive: /content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/tongue-tod5c-1


In [4]:
# Explore the downloaded dataset structure in Google Drive
import os

drive_dataset_path = '/content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/tongue-tod5c-1'

print("📊 Exploring downloaded dataset structure...")
print("=" * 50)

# Check what's in the main dataset folder
print(f"📂 Contents of: {drive_dataset_path}")
if os.path.exists(drive_dataset_path):
    main_contents = os.listdir(drive_dataset_path)
    for item in main_contents:
        item_path = os.path.join(drive_dataset_path, item)
        if os.path.isdir(item_path):
            file_count = len(os.listdir(item_path))
            print(f"  📁 {item}/: {file_count} items")
        else:
            print(f"  📄 {item}")
else:
    print("❌ Dataset path not found")

# Check for train/valid/test splits
splits = ['train', 'valid', 'test']
total_images = 0

for split in splits:
    split_path = os.path.join(drive_dataset_path, split)
    if os.path.exists(split_path):
        print(f"\n📁 {split.upper()} SET:")
        split_total = 0

        # Check classes in this split
        for class_folder in sorted(os.listdir(split_path)):
            class_path = os.path.join(split_path, class_folder)
            if os.path.isdir(class_path):
                image_count = len([f for f in os.listdir(class_path)
                                 if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))])
                if image_count > 0:  # Only show folders with images
                    split_total += image_count
                    print(f"  • {class_folder}: {image_count} images")

        total_images += split_total
        print(f"  📊 {split} subtotal: {split_total} images")

print(f"\n🎯 TOTAL IMAGES: {total_images}")
print(f"💾 Dataset successfully stored in Google Drive!")


📊 Exploring downloaded dataset structure...
📂 Contents of: /content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/tongue-tod5c-1
❌ Dataset path not found

🎯 TOTAL IMAGES: 0
💾 Dataset successfully stored in Google Drive!


In [5]:
# Find where the dataset actually got downloaded
import os

raw_dataset_path = '/content/drive/MyDrive/Tongue_Disease_AI/raw_dataset'

print("🔍 Searching for the actual dataset location...")
print("=" * 50)

# Check what's actually in the raw_dataset folder
print(f"📂 Contents of raw_dataset folder:")
if os.path.exists(raw_dataset_path):
    contents = os.listdir(raw_dataset_path)
    print(f"Found {len(contents)} items:")

    for item in contents:
        item_path = os.path.join(raw_dataset_path, item)
        if os.path.isdir(item_path):
            subitem_count = len(os.listdir(item_path))
            print(f"  📁 {item}/: {subitem_count} items")

            # Check if this looks like our dataset
            if subitem_count > 100:  # Likely our dataset with 10k+ files
                print(f"    🎯 This is likely our dataset!")
                actual_dataset_path = item_path

                # Quick peek inside
                subitems = os.listdir(item_path)[:10]  # First 10 items
                print(f"    📋 Contents: {subitems}")
        else:
            print(f"  📄 {item}")
else:
    print("❌ Raw dataset folder not found")

# Also check the current working directory (where Roboflow downloads by default)
print(f"\n🔍 Also checking current directory: {os.getcwd()}")
current_contents = [item for item in os.listdir('.') if 'tongue' in item.lower() or 'Tongue' in item]
if current_contents:
    print(f"Found tongue-related folders in current directory: {current_contents}")


🔍 Searching for the actual dataset location...
📂 Contents of raw_dataset folder:
Found 1 items:
  📁 Tongue-1/: 5 items

🔍 Also checking current directory: /content/drive/MyDrive/Tongue_Disease_AI/raw_dataset
Found tongue-related folders in current directory: ['Tongue-1']


In [6]:
# Explore the actual dataset structure inside Tongue-1 folder
import os

# Update the correct dataset path
actual_dataset_path = '/content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/Tongue-1'

print("📊 Exploring Tongue-1 dataset structure...")
print("=" * 50)
print(f"📂 Dataset location: {actual_dataset_path}")

# Check the 5 items inside Tongue-1
print(f"\n📋 Contents of Tongue-1:")
for item in os.listdir(actual_dataset_path):
    item_path = os.path.join(actual_dataset_path, item)
    if os.path.isdir(item_path):
        file_count = len(os.listdir(item_path))
        print(f"  📁 {item}/: {file_count} items")

        # If it's a folder with many items, peek inside
        if file_count > 10:
            subitems = sorted(os.listdir(item_path))[:5]  # First 5 items
            print(f"    📋 Sample contents: {subitems}")

            # Check if these are class folders or image files
            first_item_path = os.path.join(item_path, subitems[0])
            if os.path.isdir(first_item_path):
                print(f"    📁 Contains subfolders (likely class folders)")
            else:
                print(f"    📄 Contains files directly")
    else:
        print(f"  📄 {item}")

# Check for typical dataset splits
splits_to_check = ['train', 'valid', 'test', 'training', 'validation', 'testing']
found_splits = []

for split in splits_to_check:
    split_path = os.path.join(actual_dataset_path, split)
    if os.path.exists(split_path):
        found_splits.append(split)

if found_splits:
    print(f"\n✅ Found dataset splits: {found_splits}")
else:
    print(f"\n⚠️ No standard train/valid/test splits found")

print(f"\n🎯 Ready to analyze the dataset structure!")


📊 Exploring Tongue-1 dataset structure...
📂 Dataset location: /content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/Tongue-1

📋 Contents of Tongue-1:
  📄 README.dataset.txt
  📄 README.roboflow.txt
  📁 test/: 116 items
    📋 Sample contents: ['colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_Stripping', 'colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_ecchymosis', 'colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_greasy', 'colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_thin', 'colorResult_grey shapeResult_ToothMarks textureResult_normal thicknessResult_greasy']
    📁 Contains subfolders (likely class folders)
  📁 train/: 166 items
    📋 Sample contents: ['colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_Stripping', 'colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_ecchymosis', 'colorResult_grey shapeResult_ToothMarks textureResult_

In [7]:
# Count actual images in each split and analyze the class combinations
import os

actual_dataset_path = '/content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/Tongue-1'

print("📊 Analyzing tongue disease dataset - Image counts per split...")
print("=" * 60)

total_images = 0
all_classes = set()

for split in ['train', 'valid', 'test']:
    split_path = os.path.join(actual_dataset_path, split)
    split_images = 0
    split_classes = 0

    print(f"\n📁 {split.upper()} SET:")
    print("-" * 30)

    class_folders = sorted(os.listdir(split_path))

    for class_folder in class_folders:
        class_path = os.path.join(split_path, class_folder)
        if os.path.isdir(class_path):
            # Count images in this class
            image_files = [f for f in os.listdir(class_path)
                          if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
            image_count = len(image_files)

            if image_count > 0:
                split_images += image_count
                split_classes += 1
                all_classes.add(class_folder)

                # Show first few classes as examples
                if split_classes <= 3:
                    print(f"  • {class_folder}: {image_count} images")

    # Show summary for this split
    if split_classes > 3:
        print(f"  ... and {split_classes - 3} more classes")

    print(f"  📊 {split} Summary: {split_images} images across {split_classes} classes")
    total_images += split_images

print(f"\n🎯 DATASET SUMMARY:")
print("=" * 30)
print(f"📈 Total Images: {total_images}")
print(f"🏷️ Total Unique Classes: {len(all_classes)}")
print(f"📊 Perfect for multi-attribute tongue diagnosis!")

# Analyze the attribute combinations
print(f"\n🔍 Sample class combinations (first 5):")
sample_classes = list(all_classes)[:5]
for i, class_name in enumerate(sample_classes, 1):
    print(f"{i}. {class_name}")

print(f"\n✅ Dataset ready for advanced tongue disease classification!")


📊 Analyzing tongue disease dataset - Image counts per split...

📁 TRAIN SET:
------------------------------
  • colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_Stripping: 4 images
  • colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_ecchymosis: 43 images
  • colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_greasy: 34 images
  ... and 163 more classes
  📊 train Summary: 7929 images across 166 classes

📁 VALID SET:
------------------------------
  • colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_Stripping: 1 images
  • colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_ecchymosis: 3 images
  • colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_greasy: 3 images
  ... and 117 more classes
  📊 valid Summary: 973 images across 120 classes

📁 TEST SET:
------------------------------
  • colorResult_grey shapeResult_ToothMarks textureResult_

In [4]:
# Setup training pipeline for the multi-attribute tongue disease dataset
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets, models
from torch.utils.data import DataLoader
import torch.optim as optim

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Enhanced data transformations for medical images
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.2),  # Conservative for medical data
    transforms.RandomRotation(5),  # Small rotation for tongues
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load datasets from Google Drive
dataset_path = '/content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/Tongue-1'
train_path = os.path.join(dataset_path, 'train')
valid_path = os.path.join(dataset_path, 'valid')
test_path = os.path.join(dataset_path, 'test')

print("📚 Loading multi-attribute tongue disease dataset...")

# Load datasets
train_dataset = datasets.ImageFolder(root=train_path, transform=transform_train)
val_dataset = datasets.ImageFolder(root=valid_path, transform=transform_val)
test_dataset = datasets.ImageFolder(root=test_path, transform=transform_val)

# Create data loaders
batch_size = 32  # Good for this dataset size
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

# Get dataset information
num_classes = len(train_dataset.classes)
class_names = train_dataset.classes

print(f"✅ Professional dataset loaded successfully!")
print(f"📈 Training samples: {len(train_dataset)}")
print(f"📊 Validation samples: {len(val_dataset)}")
print(f"🧪 Test samples: {len(test_dataset)}")
print(f"🏷️ Number of classes: {num_classes}")
print(f"📝 Multi-attribute tongue diagnosis ready!")

# Show sample class names to verify
print(f"\n🔍 Sample classes (first 3):")
for i, class_name in enumerate(class_names[:3]):
    print(f"{i+1}. {class_name}")

print(f"\n🎯 Ready for DenseNet-121 model initialization!")


🚀 Using device: cuda
GPU: Tesla T4
GPU Memory: 15.8 GB
📚 Loading multi-attribute tongue disease dataset...
✅ Professional dataset loaded successfully!
📈 Training samples: 7929
📊 Validation samples: 973
🧪 Test samples: 988
🏷️ Number of classes: 166
📝 Multi-attribute tongue diagnosis ready!

🔍 Sample classes (first 3):
1. colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_Stripping
2. colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_ecchymosis
3. colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_greasy

🎯 Ready for DenseNet-121 model initialization!


In [5]:
# Initialize DenseNet-121 model for 166-class multi-attribute tongue diagnosis
class TongueDiseaseClassifier(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super(TongueDiseaseClassifier, self).__init__()

        # Load pre-trained DenseNet-121
        self.densenet = models.densenet121(pretrained=pretrained)

        # Get number of features from the original classifier
        num_features = self.densenet.classifier.in_features

        # Custom classifier for multi-attribute tongue diagnosis
        self.densenet.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, 1024),  # Larger layer for complex classification
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)  # 166 classes
        )

    def forward(self, x):
        return self.densenet(x)

# Initialize model
print("🤖 Initializing DenseNet-121 for multi-attribute tongue diagnosis...")

model = TongueDiseaseClassifier(num_classes=num_classes, pretrained=True)
model = model.to(device)

# Training components optimized for multi-class medical data
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=7, verbose=True)

# Model information
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✅ Multi-attribute classifier initialized successfully!")
print(f"📊 Total parameters: {total_params:,}")
print(f"🎯 Trainable parameters: {trainable_params:,}")
print(f"🏗️ Architecture: DenseNet-121 → 1024 → 512 → {num_classes} classes")
print(f"🧠 Capable of predicting: Color + Shape + Texture + Thickness")
print(f"💪 Model ready for training on {len(train_dataset)} medical images!")

print(f"\n🚀 Ready for training! This will be a sophisticated medical AI system!")


🤖 Initializing DenseNet-121 for multi-attribute tongue diagnosis...


Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 146MB/s]


✅ Multi-attribute classifier initialized successfully!
📊 Total parameters: 8,613,414
🎯 Trainable parameters: 8,613,414
🏗️ Architecture: DenseNet-121 → 1024 → 512 → 166 classes
🧠 Capable of predicting: Color + Shape + Texture + Thickness
💪 Model ready for training on 7929 medical images!

🚀 Ready for training! This will be a sophisticated medical AI system!




In [6]:
# Create simplified dataset using the color_shape strategy (12 classes - optimal balance)
import os
import shutil
from collections import defaultdict, Counter

def create_simplified_dataset(source_path, target_path, strategy='color_shape'):
    """Create simplified dataset based on chosen strategy"""

    print(f"🔄 Creating simplified dataset using '{strategy}' strategy...")

    # Define the simplification function
    strategies = {
        'color_only': lambda color, shape, texture, thickness: f"color_{color}",
        'shape_only': lambda color, shape, texture, thickness: f"shape_{shape}",
        'color_shape': lambda color, shape, texture, thickness: f"color_{color}_shape_{shape}",
        'texture_thickness': lambda color, shape, texture, thickness: f"texture_{texture}_thickness_{thickness}"
    }

    simplify_func = strategies[strategy]

    # Process each split
    for split in ['train', 'valid', 'test']:
        source_split_path = os.path.join(source_path, split)
        target_split_path = os.path.join(target_path, split)

        if os.path.exists(source_split_path):
            print(f"\n📂 Processing {split} split...")

            # Count images per simplified class
            simplified_counts = defaultdict(int)

            # First pass: count what we'll have
            for class_folder in os.listdir(source_split_path):
                class_path = os.path.join(source_split_path, class_folder)
                if os.path.isdir(class_path):
                    # Extract attributes
                    color, shape, texture, thickness = extract_primary_attributes(class_folder)
                    simplified_class = simplify_func(color, shape, texture, thickness)

                    # Count images in this original class
                    image_count = len([f for f in os.listdir(class_path)
                                     if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                    simplified_counts[simplified_class] += image_count

            # Second pass: copy images to simplified structure
            for class_folder in os.listdir(source_split_path):
                class_path = os.path.join(source_split_path, class_folder)
                if os.path.isdir(class_path):
                    # Extract attributes
                    color, shape, texture, thickness = extract_primary_attributes(class_folder)
                    simplified_class = simplify_func(color, shape, texture, thickness)

                    # Create target directory
                    target_class_path = os.path.join(target_split_path, simplified_class)
                    os.makedirs(target_class_path, exist_ok=True)

                    # Copy all images from this class
                    for image_file in os.listdir(class_path):
                        if image_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                            source_image = os.path.join(class_path, image_file)
                            # Create unique filename to avoid conflicts
                            unique_filename = f"{class_folder}_{image_file}"
                            target_image = os.path.join(target_class_path, unique_filename)
                            shutil.copy2(source_image, target_image)

            # Show results for this split
            print(f"  ✅ {split} simplified classes:")
            for class_name, count in sorted(simplified_counts.items()):
                print(f"    • {class_name}: {count} images")

    return target_path

# Create the simplified dataset
original_dataset_path = '/content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/Tongue-1'
simplified_dataset_path = '/content/drive/MyDrive/Tongue_Disease_AI/simplified_dataset'

# Use color_shape strategy (12 classes - good balance)
created_path = create_simplified_dataset(
    original_dataset_path,
    simplified_dataset_path,
    strategy='color_shape'
)

print(f"\n🎯 SIMPLIFIED DATASET SUMMARY:")
print("=" * 40)

# Analyze the final simplified dataset
total_images = 0
total_classes = 0

for split in ['train', 'valid', 'test']:
    split_path = os.path.join(simplified_dataset_path, split)
    if os.path.exists(split_path):
        split_classes = 0
        split_images = 0

        for class_folder in os.listdir(split_path):
            class_path = os.path.join(split_path, class_folder)
            if os.path.isdir(class_path):
                image_count = len([f for f in os.listdir(class_path)
                                 if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                split_classes += 1
                split_images += image_count

        total_classes = split_classes  # Should be same across splits
        total_images += split_images
        print(f"📊 {split}: {split_images} images across {split_classes} classes")

print(f"\n🏆 FINAL RESULTS:")
print(f"📈 Total Images: {total_images}")
print(f"🏷️ Total Classes: {total_classes}")
print(f"📊 Average images per class: {total_images//total_classes if total_classes > 0 else 0}")
print(f"💾 Simplified dataset saved at: {simplified_dataset_path}")

print(f"\n✅ Ready for retraining with much better performance expected!")


🔄 Creating simplified dataset using 'color_shape' strategy...

📂 Processing train split...


NameError: name 'extract_primary_attributes' is not defined

In [7]:
# Debug the dataset creation process to identify the issue
import os

print("🔍 Debugging dataset creation issue...")
print("=" * 50)

# Check if source paths exist
original_dataset_path = '/content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/Tongue-1'
simplified_dataset_path = '/content/drive/MyDrive/Tongue_Disease_AI/simplified_dataset'

print("📂 Checking source dataset paths...")
for split in ['train', 'valid', 'test']:
    source_split_path = os.path.join(original_dataset_path, split)
    print(f"  {split}: {os.path.exists(source_split_path)} - {source_split_path}")

    if os.path.exists(source_split_path):
        class_count = len([d for d in os.listdir(source_split_path)
                          if os.path.isdir(os.path.join(source_split_path, d))])
        print(f"    └── {class_count} class folders found")

# Check if the extract_primary_attributes function is working
print(f"\n🔧 Testing attribute extraction...")
if os.path.exists(os.path.join(original_dataset_path, 'train')):
    sample_classes = os.listdir(os.path.join(original_dataset_path, 'train'))[:3]
    for sample_class in sample_classes:
        color, shape, texture, thickness = extract_primary_attributes(sample_class)
        print(f"  Class: {sample_class}")
        print(f"    → Color: {color}, Shape: {shape}, Texture: {texture}, Thickness: {thickness}")

# Check what was actually created in simplified dataset
print(f"\n📊 Checking simplified dataset contents...")
if os.path.exists(simplified_dataset_path):
    for split in ['train', 'valid', 'test']:
        split_path = os.path.join(simplified_dataset_path, split)
        if os.path.exists(split_path):
            contents = os.listdir(split_path)
            print(f"  {split}: {len(contents)} items - {contents[:5]}")
        else:
            print(f"  {split}: folder doesn't exist")
else:
    print("  ❌ Simplified dataset folder doesn't exist")

print(f"\n🎯 Diagnosis complete - ready for fix!")


🔍 Debugging dataset creation issue...
📂 Checking source dataset paths...
  train: True - /content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/Tongue-1/train
    └── 166 class folders found
  valid: True - /content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/Tongue-1/valid
    └── 120 class folders found
  test: True - /content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/Tongue-1/test
    └── 116 class folders found

🔧 Testing attribute extraction...


NameError: name 'extract_primary_attributes' is not defined

In [8]:
# Find where the dataset actually exists and fix the paths
import os

print("🔍 Searching for the actual dataset location...")
print("=" * 50)

# Search in the raw_dataset folder
raw_dataset_base = '/content/drive/MyDrive/Tongue_Disease_AI/raw_dataset'

if os.path.exists(raw_dataset_base):
    print(f"📂 Contents of raw_dataset folder:")
    for item in os.listdir(raw_dataset_base):
        item_path = os.path.join(raw_dataset_base, item)
        print(f"  📁 {item}")

        # Check if this folder contains train/valid/test
        if os.path.isdir(item_path):
            subitems = os.listdir(item_path)
            has_splits = any(split in subitems for split in ['train', 'valid', 'test'])
            if has_splits:
                print(f"    ✅ Contains dataset splits: {[s for s in subitems if s in ['train', 'valid', 'test']]}")

                # This is likely our dataset - check image counts
                for split in ['train', 'valid', 'test']:
                    split_path = os.path.join(item_path, split)
                    if os.path.exists(split_path):
                        class_folders = [d for d in os.listdir(split_path)
                                       if os.path.isdir(os.path.join(split_path, d))]
                        total_images = 0
                        for cf in class_folders[:3]:  # Check first 3 classes
                            cf_path = os.path.join(split_path, cf)
                            img_count = len([f for f in os.listdir(cf_path)
                                           if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                            total_images += img_count

                        print(f"      📊 {split}: {len(class_folders)} classes, ~{total_images} images (from first 3 classes)")

                        # Show sample class names
                        if class_folders:
                            print(f"      📝 Sample classes: {class_folders[:2]}")

                # This is our correct dataset path!
                correct_dataset_path = item_path
                print(f"\n🎯 FOUND DATASET AT: {correct_dataset_path}")
            else:
                print(f"    📄 Contains: {subitems[:3]}...")

print(f"\n✅ Dataset location identified!")


🔍 Searching for the actual dataset location...
📂 Contents of raw_dataset folder:
  📁 Tongue-1
    ✅ Contains dataset splits: ['test', 'train', 'valid']
      📊 train: 166 classes, ~81 images (from first 3 classes)
      📝 Sample classes: ['colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_Stripping', 'colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_ecchymosis']
      📊 valid: 120 classes, ~7 images (from first 3 classes)
      📝 Sample classes: ['colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_Stripping', 'colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_ecchymosis']
      📊 test: 116 classes, ~15 images (from first 3 classes)
      📝 Sample classes: ['colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_Stripping', 'colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_ecchymosis']

🎯 FOUND DATASET AT: /content/drive/MyDrive/Tongue_Disease

In [9]:
# Manually explore the raw_dataset folder to find where the tongue dataset is located
import os

print("🔍 Manual exploration of raw_dataset folder...")
print("=" * 60)

raw_dataset_base = '/content/drive/MyDrive/Tongue_Disease_AI/raw_dataset'

def explore_folder_deeply(folder_path, max_depth=3, current_depth=0):
    """Recursively explore folders to find the dataset"""
    if current_depth >= max_depth or not os.path.exists(folder_path):
        return None

    indent = "  " * current_depth
    items = os.listdir(folder_path)

    print(f"{indent}📁 {os.path.basename(folder_path)}/ ({len(items)} items)")

    # Check if current folder has train/valid/test structure
    has_train = 'train' in items
    has_valid = 'valid' in items or 'validation' in items
    has_test = 'test' in items

    if has_train and (has_valid or has_test):
        print(f"{indent}🎯 DATASET FOUND! This folder has train/valid/test structure")

        # Count images in each split
        for split in ['train', 'valid', 'test']:
            split_path = os.path.join(folder_path, split)
            if os.path.exists(split_path):
                # Count class folders
                class_folders = [d for d in os.listdir(split_path)
                               if os.path.isdir(os.path.join(split_path, d))]

                # Count total images
                total_images = 0
                for class_folder in class_folders[:5]:  # Check first 5 classes
                    class_path = os.path.join(split_path, class_folder)
                    image_count = len([f for f in os.listdir(class_path)
                                     if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                    total_images += image_count

                print(f"{indent}  📊 {split}: {len(class_folders)} classes, ~{total_images * len(class_folders) // 5 if class_folders else 0} total images")

                # Show sample class names
                if class_folders:
                    print(f"{indent}  📝 Sample classes:")
                    for class_name in class_folders[:2]:
                        print(f"{indent}    • {class_name}")

        return folder_path

    # Explore subdirectories
    for item in items:
        item_path = os.path.join(folder_path, item)
        if os.path.isdir(item_path):
            print(f"{indent}  📁 {item}/")
            result = explore_folder_deeply(item_path, max_depth, current_depth + 1)
            if result:
                return result
        else:
            # Show first few files
            if items.index(item) < 3:
                print(f"{indent}  📄 {item}")

    if len(items) > 3:
        file_count = len([i for i in items if os.path.isfile(os.path.join(folder_path, i))])
        if file_count > 3:
            print(f"{indent}  ... and {file_count - 3} more files")

    return None

print(f"🔍 Exploring: {raw_dataset_base}")

if os.path.exists(raw_dataset_base):
    dataset_location = explore_folder_deeply(raw_dataset_base)

    if dataset_location:
        print(f"\n🎯 DATASET FOUND AT: {dataset_location}")
        print(f"✅ This is where your 9,890 tongue images are stored!")

        # Store the path for next steps
        print(f"\n📋 Use this path for simplified dataset creation:")
        print(f"correct_dataset_path = '{dataset_location}'")
    else:
        print(f"\n❌ No dataset with train/valid/test structure found")
        print(f"🔍 Let's check if files are in a different structure...")

        # Alternative: look for any folder with many images
        print(f"\n🔍 Looking for folders with many images...")
        for item in os.listdir(raw_dataset_base):
            item_path = os.path.join(raw_dataset_base, item)
            if os.path.isdir(item_path):
                # Count total files recursively
                total_files = 0
                for root, dirs, files in os.walk(item_path):
                    total_files += len([f for f in files if f.lower().endswith(('.jpg', '.jpeg', '.png'))])

                if total_files > 100:  # Likely our dataset
                    print(f"  📊 {item}: {total_files} image files")
                    print(f"      Path: {item_path}")
else:
    print(f"❌ Raw dataset folder doesn't exist: {raw_dataset_base}")

    # Check if the base project folder exists
    project_base = '/content/drive/MyDrive/Tongue_Disease_AI'
    if os.path.exists(project_base):
        print(f"\n📂 Contents of project folder:")
        for item in os.listdir(project_base):
            print(f"  📁 {item}")
    else:
        print(f"❌ Project folder doesn't exist either!")

print(f"\n🎯 Manual exploration complete!")


🔍 Manual exploration of raw_dataset folder...
🔍 Exploring: /content/drive/MyDrive/Tongue_Disease_AI/raw_dataset
📁 raw_dataset/ (1 items)
  📁 Tongue-1/
  📁 Tongue-1/ (5 items)
  🎯 DATASET FOUND! This folder has train/valid/test structure
    📊 train: 166 classes, ~3220 total images
    📝 Sample classes:
      • colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_Stripping
      • colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_ecchymosis
    📊 valid: 120 classes, ~240 total images
    📝 Sample classes:
      • colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_Stripping
      • colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_ecchymosis
    📊 test: 116 classes, ~394 total images
    📝 Sample classes:
      • colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_Stripping
      • colorResult_grey shapeResult_ToothMarks textureResult_dark thicknessResult_ecchymosis

🎯 D

In [10]:
# Create simplified dataset using the confirmed correct dataset path
import os
import shutil
from collections import defaultdict

# Use the confirmed dataset path
correct_dataset_path = '/content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/Tongue-1'
simplified_dataset_path = '/content/drive/MyDrive/Tongue_Disease_AI/simplified_dataset'

print("🔄 Creating simplified color_shape dataset...")
print(f"📂 Source: {correct_dataset_path}")
print(f"📂 Target: {simplified_dataset_path}")
print("=" * 60)

# Clear any existing simplified dataset
if os.path.exists(simplified_dataset_path):
    shutil.rmtree(simplified_dataset_path)

# Create simplified dataset with color_shape strategy
total_processed = 0

for split in ['train', 'valid', 'test']:
    source_split_path = os.path.join(correct_dataset_path, split)
    target_split_path = os.path.join(simplified_dataset_path, split)

    if os.path.exists(source_split_path):
        print(f"\n📂 Processing {split} split...")

        simplified_counts = defaultdict(int)
        os.makedirs(target_split_path, exist_ok=True)

        # Process each original class
        class_folders = os.listdir(source_split_path)

        for i, class_folder in enumerate(class_folders):
            class_path = os.path.join(source_split_path, class_folder)
            if os.path.isdir(class_path):
                # Extract attributes using the function we defined earlier
                color, shape, texture, thickness = extract_primary_attributes(class_folder)

                # Create simplified class name: color_shape
                simplified_class = f"color_{color}_shape_{shape}"

                # Create target directory
                target_class_path = os.path.join(target_split_path, simplified_class)
                os.makedirs(target_class_path, exist_ok=True)

                # Copy images
                copied_count = 0
                for image_file in os.listdir(class_path):
                    if image_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                        source_image = os.path.join(class_path, image_file)
                        # Create unique filename to avoid conflicts
                        unique_filename = f"{class_folder.replace(' ', '_')}_{image_file}"
                        target_image = os.path.join(target_class_path, unique_filename)
                        shutil.copy2(source_image, target_image)
                        copied_count += 1

                simplified_counts[simplified_class] += copied_count
                total_processed += copied_count

                # Progress indicator
                if (i + 1) % 50 == 0:
                    print(f"    ✅ Processed {i + 1}/{len(class_folders)} classes...")

        # Show results for this split
        print(f"  📊 Created {len(simplified_counts)} simplified classes:")
        for class_name, count in sorted(simplified_counts.items()):
            print(f"    • {class_name}: {count} images")

# Final verification
print(f"\n🎯 SIMPLIFIED DATASET SUMMARY:")
print("=" * 50)

total_images = 0
all_classes = set()

for split in ['train', 'valid', 'test']:
    split_path = os.path.join(simplified_dataset_path, split)
    if os.path.exists(split_path):
        split_images = 0
        classes = [d for d in os.listdir(split_path)
                  if os.path.isdir(os.path.join(split_path, d))]

        for class_folder in classes:
            all_classes.add(class_folder)
            class_path = os.path.join(split_path, class_folder)
            image_count = len([f for f in os.listdir(class_path)
                             if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
            split_images += image_count

        total_images += split_images
        print(f"📊 {split}: {split_images} images across {len(classes)} classes")

print(f"\n🏆 FINAL RESULTS:")
print(f"📈 Total Images: {total_images}")
print(f"🏷️ Total Unique Classes: {len(all_classes)}")
print(f"📊 Average per class: {total_images//len(all_classes) if all_classes else 0}")
print(f"💾 Simplified dataset location: {simplified_dataset_path}")

print(f"\n📝 Simplified classes created:")
for class_name in sorted(all_classes):
    print(f"  • {class_name}")

if total_images > 8000:  # Expect ~9,890 images
    print(f"\n✅ SUCCESS! Ready for training with {len(all_classes)}-class simplified dataset!")
    print(f"🚀 Expected accuracy improvement: 0.21% → 60-80%+")
else:
    print(f"\n⚠️ Only {total_images} images found - some may have been missed")

print(f"\n🎯 Ready for next step: Retrain model with simplified dataset!")


🔄 Creating simplified color_shape dataset...
📂 Source: /content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/Tongue-1
📂 Target: /content/drive/MyDrive/Tongue_Disease_AI/simplified_dataset

📂 Processing train split...


NameError: name 'extract_primary_attributes' is not defined

In [None]:
# Define the missing extract_primary_attributes function and create simplified dataset
import os
import shutil
from collections import defaultdict

def extract_primary_attributes(class_name):
    """Extract and simplify the primary attributes from complex class names"""

    # Initialize default values
    color = "unknown"
    shape = "unknown"
    texture = "unknown"
    thickness = "unknown"

    # Extract attributes by splitting and parsing each part
    parts = class_name.split(' ')
    for part in parts:
        if 'colorResult_' in part:
            color = part.replace('colorResult_', '')
        elif 'shapeResult_' in part:
            shape = part.replace('shapeResult_', '')
        elif 'textureResult_' in part:
            texture = part.replace('textureResult_', '')
        elif 'thicknessResult_' in part:
            thickness = part.replace('thicknessResult_', '')

    return color, shape, texture, thickness

# Now create the simplified dataset with the function properly defined
correct_dataset_path = '/content/drive/MyDrive/Tongue_Disease_AI/raw_dataset/Tongue-1'
simplified_dataset_path = '/content/drive/MyDrive/Tongue_Disease_AI/simplified_dataset'

print("🔄 Creating simplified color_shape dataset...")
print(f"📂 Source: {correct_dataset_path}")
print(f"📂 Target: {simplified_dataset_path}")
print("=" * 60)

# Clear any existing simplified dataset
if os.path.exists(simplified_dataset_path):
    shutil.rmtree(simplified_dataset_path)

# Create simplified dataset with color_shape strategy
total_processed = 0

for split in ['train', 'valid', 'test']:
    source_split_path = os.path.join(correct_dataset_path, split)
    target_split_path = os.path.join(simplified_dataset_path, split)

    if os.path.exists(source_split_path):
        print(f"\n📂 Processing {split} split...")

        simplified_counts = defaultdict(int)
        os.makedirs(target_split_path, exist_ok=True)

        # Process each original class
        class_folders = os.listdir(source_split_path)

        for i, class_folder in enumerate(class_folders):
            class_path = os.path.join(source_split_path, class_folder)
            if os.path.isdir(class_path):
                # Extract attributes using our function
                color, shape, texture, thickness = extract_primary_attributes(class_folder)

                # Create simplified class name: color_shape
                simplified_class = f"color_{color}_shape_{shape}"

                # Create target directory
                target_class_path = os.path.join(target_split_path, simplified_class)
                os.makedirs(target_class_path, exist_ok=True)

                # Copy images
                copied_count = 0
                for image_file in os.listdir(class_path):
                    if image_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                        source_image = os.path.join(class_path, image_file)
                        # Create unique filename to avoid conflicts
                        unique_filename = f"{class_folder.replace(' ', '_')}_{image_file}"
                        target_image = os.path.join(target_class_path, unique_filename)
                        shutil.copy2(source_image, target_image)
                        copied_count += 1

                simplified_counts[simplified_class] += copied_count
                total_processed += copied_count

                # Progress indicator every 50 classes
                if (i + 1) % 50 == 0:
                    print(f"    ✅ Processed {i + 1}/{len(class_folders)} classes...")

        # Show results for this split
        print(f"  📊 Created {len(simplified_counts)} simplified classes:")
        for class_name, count in sorted(simplified_counts.items()):
            print(f"    • {class_name}: {count} images")

# Final verification
print(f"\n🎯 SIMPLIFIED DATASET SUMMARY:")
print("=" * 50)

total_images = 0
all_classes = set()

for split in ['train', 'valid', 'test']:
    split_path = os.path.join(simplified_dataset_path, split)
    if os.path.exists(split_path):
        split_images = 0
        classes = [d for d in os.listdir(split_path)
                  if os.path.isdir(os.path.join(split_path, d))]

        for class_folder in classes:
            all_classes.add(class_folder)
            class_path = os.path.join(split_path, class_folder)
            image_count = len([f for f in os.listdir(class_path)
                             if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
            split_images += image_count

        total_images += split_images
        print(f"📊 {split}: {split_images} images across {len(classes)} classes")

print(f"\n🏆 FINAL RESULTS:")
print(f"📈 Total Images: {total_images}")
print(f"🏷️ Total Unique Classes: {len(all_classes)}")
print(f"📊 Average per class: {total_images//len(all_classes) if all_classes else 0}")
print(f"💾 Simplified dataset location: {simplified_dataset_path}")

print(f"\n📝 Simplified classes created:")
for class_name in sorted(all_classes):
    print(f"  • {class_name}")

if total_images > 8000:
    print(f"\n✅ SUCCESS! Ready for training with {len(all_classes)}-class simplified dataset!")
    print(f"🚀 Expected accuracy improvement: 0.21% → 60-80%+")
else:
    print(f"\n⚠️ Only {total_images} images found - investigating...")

print(f"\n🎯 Ready for retraining with much better performance expected!")


In [13]:
# Retrain the DenseNet-121 model with the simplified 12-class dataset
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets, models
from torch.utils.data import DataLoader
import torch.optim as optim

print("🚀 Setting up training with simplified 12-class dataset...")
print("=" * 60)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Enhanced data transformations for the simplified dataset
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15),
    transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the simplified dataset
simplified_dataset_path = '/content/drive/MyDrive/Tongue_Disease_AI/simplified_dataset'
train_path = os.path.join(simplified_dataset_path, 'train')
valid_path = os.path.join(simplified_dataset_path, 'valid')
test_path = os.path.join(simplified_dataset_path, 'test')

# Load datasets
train_dataset = datasets.ImageFolder(root=train_path, transform=transform_train)
val_dataset = datasets.ImageFolder(root=valid_path, transform=transform_val)
test_dataset = datasets.ImageFolder(root=test_path, transform=transform_val)

# Create data loaders with larger batch size (more manageable dataset)
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

# Get dataset information
num_classes = len(train_dataset.classes)
class_names = train_dataset.classes

print(f"✅ Simplified dataset loaded successfully!")
print(f"📈 Training samples: {len(train_dataset)}")
print(f"📊 Validation samples: {len(val_dataset)}")
print(f"🧪 Test samples: {len(test_dataset)}")
print(f"🏷️ Number of classes: {num_classes}")

print(f"\n📝 Simplified classes:")
for i, class_name in enumerate(class_names):
    print(f"{i+1:2d}. {class_name}")

# Initialize new model for simplified classification
class SimplifiedTongueClassifier(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super(SimplifiedTongueClassifier, self).__init__()

        # Load pre-trained DenseNet-121
        self.densenet = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None)

        # Get number of features
        num_features = self.densenet.classifier.in_features

        # Simpler classifier for 12 classes (less overfitting risk)
        self.densenet.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(num_features, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        return self.densenet(x)

# Initialize the simplified model
print(f"\n🤖 Initializing simplified tongue classifier...")
model = SimplifiedTongueClassifier(num_classes=num_classes, pretrained=True)
model = model.to(device)

# Training components optimized for simplified dataset
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

print(f"✅ Simplified model initialized!")
print(f"📊 Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"🎯 Ready for training on {num_classes} classes!")
print(f"\n🚀 Expected performance: 60-80%+ accuracy (vs previous 0.21%)")


🚀 Setting up training with simplified 12-class dataset...
Using device: cpu
✅ Simplified dataset loaded successfully!
📈 Training samples: 7929
📊 Validation samples: 973
🧪 Test samples: 988
🏷️ Number of classes: 12

📝 Simplified classes:
 1. color_grey_shape_ToothMarks
 2. color_grey_shape_fat
 3. color_grey_shape_normal
 4. color_grey_shape_thin
 5. color_white_shape_ToothMarks
 6. color_white_shape_fat
 7. color_white_shape_normal
 8. color_white_shape_thin
 9. color_yellow_shape_ToothMarks
10. color_yellow_shape_fat
11. color_yellow_shape_normal
12. color_yellow_shape_thin

🤖 Initializing simplified tongue classifier...


Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 77.6MB/s]

✅ Simplified model initialized!
📊 Total parameters: 7,219,340
🎯 Ready for training on 12 classes!

🚀 Expected performance: 60-80%+ accuracy (vs previous 0.21%)





In [1]:
# Training function optimized for the 12-class simplified dataset
def train_simplified_tongue_classifier(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=25):
    """Training function optimized for simplified 12-class tongue diagnosis"""

    print(f"🏋️ Starting training for {num_epochs} epochs on simplified dataset...")
    print(f"🎯 Training: {len(train_dataset)} samples | Validation: {len(val_dataset)} samples")
    print("=" * 60)

    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': []
    }

    best_val_acc = 0.0
    best_model_path = '/content/drive/MyDrive/Tongue_Disease_AI/models/best_simplified_classifier.pth'

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 40)

        # Training phase
        model.train()
        running_loss = 0.0
        running_corrects = 0

        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

            # Progress every 30 batches
            if (batch_idx + 1) % 30 == 0:
                batch_acc = torch.sum(preds == labels.data).double() / inputs.size(0)
                print(f'    Batch {batch_idx+1}/{len(train_loader)}: Loss={loss.item():.4f}, Acc={batch_acc:.4f}')

        epoch_loss = running_loss / len(train_dataset)
        epoch_acc = running_corrects.double() / len(train_dataset)

        # Validation phase
        model.eval()
        val_running_loss = 0.0
        val_running_corrects = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                val_running_loss += loss.item() * inputs.size(0)
                val_running_corrects += torch.sum(preds == labels.data)

        val_loss = val_running_loss / len(val_dataset)
        val_acc = val_running_corrects.double() / len(val_dataset)

        # Update scheduler
        scheduler.step(val_acc)

        # Save metrics
        history['train_loss'].append(epoch_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(epoch_acc.item())
        history['val_acc'].append(val_acc.item())

        # Print epoch results
        print(f"📈 Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.4f} ({epoch_acc*100:.2f}%)")
        print(f"📊 Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} ({val_acc*100:.2f}%)")

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'class_names': class_names,
                'num_classes': num_classes,
                'history': history
            }, best_model_path)
            print(f"🎯 New best model saved! Accuracy: {val_acc:.4f} ({val_acc*100:.2f}%)")

        # GPU memory cleanup every 5 epochs
        if (epoch + 1) % 5 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\n" + "=" * 60)
    print(f"🏆 Training completed! Best validation accuracy: {best_val_acc:.4f} ({best_val_acc*100:.2f}%)")
    print(f"🚀 Improvement: 0.21% → {best_val_acc*100:.2f}% (Expected: 60-80%+)")
    print(f"💾 Best model saved to: {best_model_path}")

    return history, best_model_path

# Start training the simplified classifier
print("🎯 Starting training of simplified 12-class tongue diagnosis system...")
training_history, best_model_path = train_simplified_tongue_classifier(
    model, train_loader, val_loader, criterion, optimizer, scheduler,
    num_epochs=25
)

print(f"\n✅ Simplified tongue classifier training completed!")
print(f"🎉 Ready for evaluation and real-world prediction!")


🎯 Starting training of simplified 12-class tongue diagnosis system...


NameError: name 'model' is not defined

In [19]:
# Fix GPU detection and optimize training for T4 GPU
import torch
import os

# Force clear any GPU cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Check current GPU status
print("🔍 GPU Status Check:")
print(f"CUDA Available: {torch.cuda.is_available()}")
print(f"CUDA Version: {torch.version.cuda}")
print(f"PyTorch Version: {torch.__version__}")

if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"Current Device: {torch.cuda.current_device()}")
else:
    print("❌ PyTorch cannot detect GPU")

# Force set device to CUDA if available
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"\n🎯 Setting device to: {device}")

# Test GPU with a simple operation
if torch.cuda.is_available():
    test_tensor = torch.randn(1000, 1000).to(device)
    result = torch.mm(test_tensor, test_tensor.t())
    print(f"✅ GPU test successful! Tensor on device: {result.device}")
    del test_tensor, result
    torch.cuda.empty_cache()
else:
    print("❌ GPU test failed - using CPU")

print(f"\n🚀 Ready to use {device} for training!")


🔍 GPU Status Check:
CUDA Available: False
CUDA Version: 12.4
PyTorch Version: 2.6.0+cu124
❌ PyTorch cannot detect GPU

🎯 Setting device to: cpu
❌ GPU test failed - using CPU

🚀 Ready to use cpu for training!
