<a href="https://colab.research.google.com/github/ilhamazhar1308/Rock-Type-Classification-Using-Deep-Learning-Approach-Based-on-ResNet-34-Architecture/blob/main/DeepLearningGeofisika_Resnet34_Kelompok_4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import re

In [None]:
from google.colab import drive

# Mount Google Drive so the dataset and checkpoint folders under MyDrive are reachable.
drive.mount('/content/drive')

# Prefer the GPU when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

Mounted at /content/drive
Device: cuda


In [None]:
# Dataset root: one sub-folder per rock class (torchvision ImageFolder layout).
data_dir = "/content/drive/MyDrive/Rocks"
full_ds = datasets.ImageFolder(root=data_dir)

# ImageFolder derives the class list from the sub-folder names (sorted) and
# maps them to integer labels 0..K-1 in that order.
class_name = full_ds.classes
num_classes = len(class_name)

# Quick summary: number of classes, their names, and total image count.
for summary_label, summary_value in (
    ("Jumlah kelas:", num_classes),
    ("Nama kelas:", class_name),
    ("Total gambar:", len(full_ds)),
):
    print(summary_label, summary_value)


Jumlah kelas: 53
Nama kelas: ['Amphibolite', 'Andesite', 'Anthracite', 'Basalt', 'Blueschist', 'Breccia', 'Carbonatite', 'Chalk', 'Chert', 'Coal', 'Conglomerate', 'Diamictite', 'Dolomite', 'Eclogite', 'Evaporite', 'Flint', 'Gabbro', 'Gneiss', 'Granite', 'Granulite', 'Greenschist', 'Greywacke', 'Hornfels', 'Komatiite', 'Limestone', 'Marble', 'Migmatite', 'Mudstone', 'Obsidian', 'Oil_shale', 'Oolite', 'Pegmatite', 'Phyllite', 'Porphyry', 'Pumice', 'Pyroxenite', 'Quartz_diorite', 'Quartz_monzonite', 'Quartzite', 'Quartzolite', 'Rhyolite', 'Sandstone', 'Scoria', 'Serpentinite', 'Shale', 'Siltstone', 'Slate', 'Talc_carbonate', 'Tephrite', 'Travertine', 'Tuff', 'Turbidite', 'Wackestone']
Total gambar: 2343


In [None]:
from google.colab import drive

# Drive is normally already mounted by the earlier cell. Only mount when MyDrive
# is missing, so re-running this cell is a silent no-op instead of printing the
# "Drive already mounted" notice.
if not os.path.isdir("/content/drive/MyDrive"):
    drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Per-channel ImageNet statistics, matching the ImageNet-pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: fixed resize followed by random augmentations
# (crop, flip, rotation, colour jitter, perspective), then tensor + normalise.
train_tfms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.RandomPerspective(distortion_scale=0.3, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation pipeline: deterministic resize only, same normalisation.
test_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


In [None]:
# 80/20 train/test split with a FIXED seed. Without a seed every new Colab
# session (e.g. when resuming from a checkpoint) draws a different split, so
# images the model already trained on end up in the "test" set (leakage).
# NOTE: checkpoints trained before this seed was introduced used a different
# random split; evaluating them on this test_ds would be optimistic.
SPLIT_SEED = 42

train_size = int(0.8 * len(full_ds))
test_size  = len(full_ds) - train_size

train_idx, test_idx = torch.utils.data.random_split(
    full_ds, [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
assert not set(train_idx.indices) & set(test_idx.indices), "train/test overlap"

# Two ImageFolder views of the same directory so each split gets its own
# transform; each is then restricted to its split's samples.
train_ds = datasets.ImageFolder(data_dir, transform=train_tfms)
test_ds  = datasets.ImageFolder(data_dir, transform=test_tfms)


def _restrict_to_split(ds, subset):
    """Keep only ``subset.indices`` of ``full_ds`` in ``ds``.

    Updates ``samples``, ``imgs`` and ``targets`` together so that all of
    ImageFolder's views stay consistent (the original only replaced
    ``samples``, leaving ``targets``/``imgs`` describing the full dataset).
    """
    split_samples = [full_ds.samples[i] for i in subset.indices]
    ds.samples = split_samples
    ds.imgs = split_samples
    ds.targets = [label for _, label in split_samples]


_restrict_to_split(train_ds, train_idx)
_restrict_to_split(test_ds, test_idx)

print("Train:", len(train_ds))
print("Test :", len(test_ds))


Train: 1874
Test : 469


In [None]:
batch_size = 32

# Options shared by both loaders (pinned memory speeds up host->GPU copies).
loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True)

# Shuffle only the training split; keep the test order fixed.
train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
test_dl = DataLoader(test_ds, shuffle=False, **loader_kwargs)


In [None]:
# ResNet-34 backbone initialised with ImageNet-pretrained weights.
model = models.resnet34(
    weights=models.ResNet34_Weights.IMAGENET1K_V1
)

# Freeze all pretrained parameters; selected parts are re-enabled below.
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier head. The feature width is read from the original fc
# layer instead of hard-coding 512, so the head always matches the backbone.
# The new modules are created after the freeze loop, so they are trainable.
backbone_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(backbone_features, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(256, num_classes)
)

# Fine-tune the last residual stage (layer4) as well.
for name, param in model.named_parameters():
    if name.startswith("layer4."):
        param.requires_grad = True

model = model.to(device)


Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


100%|██████████| 83.3M/83.3M [00:00<00:00, 169MB/s]


In [None]:
# Cross-entropy whose targets are mixed with a uniform distribution (weight 0.15).
criterion = nn.CrossEntropyLoss(label_smoothing=0.15)

# Optimise only the parameters that are still trainable after freezing.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-3)


In [None]:
def accuracy(logits, y):
    """Return the fraction of samples whose arg-max score equals the label.

    Args:
        logits: score tensor with classes along dim 1.
        y: integer class labels, one per row of ``logits``.

    Returns:
        A Python float (mean computed in float32).
    """
    predicted = torch.argmax(logits, dim=1)
    hits = predicted.eq(y).to(torch.float32)
    return hits.mean().item()


In [None]:
def train_one_epoch(model, loader, opt=None, loss_fn=None, dev=None):
    """Train ``model`` for one full pass over ``loader``.

    Args:
        model: network to train; switched to train mode here.
        loader: DataLoader yielding ``(inputs, labels)`` batches.
        opt: optimizer to step. Defaults to the notebook-level ``optimizer``.
        loss_fn: callable ``loss_fn(logits, labels)``. Defaults to the
            notebook-level ``criterion``.
        dev: device batches are moved to. Defaults to the notebook-level
            ``device``.

    Returns:
        ``(loss, acc)``: per-sample averages over ``loader.dataset``.
        Accuracy is measured in train mode (dropout active), so it is not
        a held-out score.

    Raises:
        ValueError: if ``loader.dataset`` is empty.
    """
    # Fall back to the notebook globals so existing two-argument calls keep working.
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    dev = device if dev is None else dev

    n_samples = len(loader.dataset)
    if n_samples == 0:
        raise ValueError("train_one_epoch: loader.dataset is empty")

    model.train()
    loss_sum, acc_sum = 0.0, 0.0

    for x, y in tqdm(loader, desc="Train", leave=False):
        # non_blocking pairs with pin_memory=True for asynchronous host->GPU copies.
        x = x.to(dev, non_blocking=True)
        y = y.to(dev, non_blocking=True)

        opt.zero_grad()
        out = model(x)
        loss = loss_fn(out, y)
        loss.backward()
        opt.step()

        # Weight by batch size so the smaller final batch is not over-counted.
        batch = x.size(0)
        loss_sum += loss.item() * batch
        acc_sum += accuracy(out, y) * batch

    return loss_sum / n_samples, acc_sum / n_samples


In [None]:
ckpt_dir = "/content/drive/MyDrive/Checkpoints_Rocks_FINAL"
os.makedirs(ckpt_dir, exist_ok=True)

epochs = 125
SAVE_EVERY = 5  # checkpoint interval, in epochs

for ep in range(1, epochs + 1):
    loss, acc = train_one_epoch(model, train_dl)

    print(f"Epoch {ep:03d} | Loss {loss:.4f} | Acc {acc*100:.2f}%")

    # Save periodically AND at the final epoch, so the last weights are not
    # lost when `epochs` is not a multiple of SAVE_EVERY.
    if ep % SAVE_EVERY == 0 or ep == epochs:
        torch.save({
            "epoch": ep,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict()
        }, os.path.join(ckpt_dir, f"checkpoint_{ep}.pth"))

        print("✅ Checkpoint saved")




Epoch 001 | Loss 4.0123 | Acc 3.95%




Epoch 002 | Loss 3.6804 | Acc 13.13%




Epoch 003 | Loss 3.4883 | Acc 18.78%




Epoch 004 | Loss 3.3498 | Acc 23.37%




Epoch 005 | Loss 3.2413 | Acc 26.31%
✅ Checkpoint saved




Epoch 006 | Loss 3.1055 | Acc 30.68%




Epoch 007 | Loss 3.0562 | Acc 32.23%




Epoch 008 | Loss 2.9709 | Acc 33.30%




Epoch 009 | Loss 2.8952 | Acc 36.50%




Epoch 010 | Loss 2.8318 | Acc 39.86%
✅ Checkpoint saved




Epoch 011 | Loss 2.8209 | Acc 39.01%




Epoch 012 | Loss 2.7357 | Acc 43.17%




Epoch 013 | Loss 2.6973 | Acc 42.64%




Epoch 014 | Loss 2.6203 | Acc 46.21%




Epoch 015 | Loss 2.6095 | Acc 45.68%
✅ Checkpoint saved




Epoch 016 | Loss 2.5660 | Acc 48.61%




Epoch 017 | Loss 2.5126 | Acc 49.89%




Epoch 018 | Loss 2.4570 | Acc 52.03%




Epoch 019 | Loss 2.4092 | Acc 52.67%




Epoch 020 | Loss 2.3999 | Acc 53.15%
✅ Checkpoint saved




Epoch 021 | Loss 2.3475 | Acc 54.75%




Epoch 022 | Loss 2.3383 | Acc 55.44%




Epoch 023 | Loss 2.3031 | Acc 56.67%




Epoch 024 | Loss 2.2459 | Acc 58.91%




Epoch 025 | Loss 2.1945 | Acc 61.15%
✅ Checkpoint saved




Epoch 026 | Loss 2.1999 | Acc 60.67%




Epoch 027 | Loss 2.1645 | Acc 60.99%




Epoch 028 | Loss 2.0935 | Acc 64.30%




Epoch 029 | Loss 2.0779 | Acc 64.99%




Epoch 030 | Loss 2.0542 | Acc 64.67%
✅ Checkpoint saved




Epoch 031 | Loss 2.0808 | Acc 65.26%




Epoch 032 | Loss 2.0147 | Acc 67.50%




Epoch 033 | Loss 2.0099 | Acc 67.34%




Epoch 034 | Loss 1.9677 | Acc 68.09%




Epoch 035 | Loss 1.9564 | Acc 69.74%
✅ Checkpoint saved




Epoch 036 | Loss 1.9412 | Acc 69.42%




Epoch 037 | Loss 1.9266 | Acc 70.22%




Epoch 038 | Loss 1.8878 | Acc 72.79%




Epoch 039 | Loss 1.8865 | Acc 71.82%




Epoch 040 | Loss 1.8475 | Acc 72.04%
✅ Checkpoint saved




Epoch 041 | Loss 1.8690 | Acc 73.32%




Epoch 042 | Loss 1.8198 | Acc 74.17%




Epoch 043 | Loss 1.8005 | Acc 74.92%




Epoch 044 | Loss 1.8057 | Acc 74.71%




Epoch 045 | Loss 1.7400 | Acc 77.64%
✅ Checkpoint saved




Epoch 046 | Loss 1.7589 | Acc 75.29%




Epoch 047 | Loss 1.7619 | Acc 74.65%




Epoch 048 | Loss 1.7875 | Acc 75.19%




Epoch 049 | Loss 1.7413 | Acc 76.36%




Epoch 050 | Loss 1.7188 | Acc 76.79%
✅ Checkpoint saved




Epoch 051 | Loss 1.7542 | Acc 75.99%




Epoch 052 | Loss 1.7359 | Acc 77.37%




Epoch 053 | Loss 1.6883 | Acc 78.66%




Epoch 054 | Loss 1.6803 | Acc 79.56%




Epoch 055 | Loss 1.6825 | Acc 79.03%
✅ Checkpoint saved




Epoch 056 | Loss 1.6899 | Acc 78.01%




Epoch 057 | Loss 1.6894 | Acc 78.23%




Epoch 058 | Loss 1.6738 | Acc 78.28%




Epoch 059 | Loss 1.6399 | Acc 79.99%




Epoch 060 | Loss 1.6463 | Acc 79.19%
✅ Checkpoint saved




Epoch 061 | Loss 1.6351 | Acc 80.04%




Epoch 062 | Loss 1.6424 | Acc 80.15%




Epoch 063 | Loss 1.5997 | Acc 80.90%




Epoch 064 | Loss 1.6666 | Acc 79.35%




Epoch 065 | Loss 1.6174 | Acc 81.06%
✅ Checkpoint saved




Epoch 066 | Loss 1.6503 | Acc 80.31%




Epoch 067 | Loss 1.6098 | Acc 81.00%




Epoch 068 | Loss 1.5732 | Acc 82.44%




Epoch 069 | Loss 1.5986 | Acc 81.43%




Epoch 070 | Loss 1.5681 | Acc 82.02%
✅ Checkpoint saved




Epoch 071 | Loss 1.5776 | Acc 82.66%




Epoch 072 | Loss 1.5805 | Acc 81.54%




Epoch 073 | Loss 1.5681 | Acc 81.43%




Epoch 074 | Loss 1.5978 | Acc 81.70%




Epoch 075 | Loss 1.5819 | Acc 81.16%
✅ Checkpoint saved




Epoch 076 | Loss 1.5521 | Acc 83.08%




Epoch 077 | Loss 1.5421 | Acc 82.76%




Epoch 078 | Loss 1.5780 | Acc 82.18%




Epoch 079 | Loss 1.5811 | Acc 81.64%




Epoch 080 | Loss 1.5430 | Acc 83.03%
✅ Checkpoint saved




Epoch 081 | Loss 1.5768 | Acc 81.06%




Epoch 082 | Loss 1.4999 | Acc 85.06%




Epoch 083 | Loss 1.5262 | Acc 83.30%




Epoch 084 | Loss 1.5379 | Acc 83.78%




Epoch 085 | Loss 1.5388 | Acc 82.50%
✅ Checkpoint saved




Epoch 086 | Loss 1.5227 | Acc 83.62%




Epoch 087 | Loss 1.5422 | Acc 82.55%




Epoch 088 | Loss 1.5149 | Acc 84.20%




Epoch 089 | Loss 1.5079 | Acc 83.99%




Epoch 090 | Loss 1.5320 | Acc 82.82%
✅ Checkpoint saved




Epoch 091 | Loss 1.5163 | Acc 83.56%




Epoch 092 | Loss 1.4931 | Acc 84.31%




Epoch 093 | Loss 1.5050 | Acc 83.83%




Epoch 094 | Loss 1.5087 | Acc 84.36%




Epoch 095 | Loss 1.4878 | Acc 84.85%
✅ Checkpoint saved




Epoch 096 | Loss 1.5219 | Acc 83.14%




Epoch 097 | Loss 1.5037 | Acc 84.31%




Epoch 098 | Loss 1.4779 | Acc 84.47%




Epoch 099 | Loss 1.5029 | Acc 83.51%




Epoch 100 | Loss 1.4942 | Acc 83.94%
✅ Checkpoint saved




Epoch 101 | Loss 1.5043 | Acc 83.78%




Epoch 102 | Loss 1.4673 | Acc 85.54%




Epoch 103 | Loss 1.4628 | Acc 84.85%




Epoch 104 | Loss 1.5034 | Acc 83.99%




Epoch 105 | Loss 1.4777 | Acc 84.26%
✅ Checkpoint saved




Epoch 106 | Loss 1.4576 | Acc 85.54%




Epoch 107 | Loss 1.4439 | Acc 85.59%




Epoch 108 | Loss 1.4573 | Acc 85.81%




Epoch 109 | Loss 1.4648 | Acc 84.74%




Epoch 110 | Loss 1.4398 | Acc 86.55%
✅ Checkpoint saved




Epoch 111 | Loss 1.4486 | Acc 85.27%




Epoch 112 | Loss 1.4406 | Acc 86.29%




Epoch 113 | Loss 1.4660 | Acc 84.15%




Epoch 114 | Loss 1.4435 | Acc 85.70%




Epoch 115 | Loss 1.4483 | Acc 84.85%
✅ Checkpoint saved




Epoch 116 | Loss 1.4178 | Acc 87.46%




Epoch 117 | Loss 1.4571 | Acc 85.22%




Epoch 118 | Loss 1.4634 | Acc 84.90%




Epoch 119 | Loss 1.4346 | Acc 86.55%




Epoch 120 | Loss 1.4272 | Acc 86.61%
✅ Checkpoint saved




Epoch 121 | Loss 1.4436 | Acc 85.17%




Epoch 122 | Loss 1.4395 | Acc 85.38%




Epoch 123 | Loss 1.4397 | Acc 85.65%




Epoch 124 | Loss 1.4489 | Acc 84.85%




Epoch 125 | Loss 1.4409 | Acc 86.07%
✅ Checkpoint saved


In [None]:
ckpt_dir = "/content/drive/MyDrive/Checkpoints_Rocks_FINAL"

def extract_epoch(fname):
    """Return the epoch number N encoded in a ``checkpoint_<N>.pth`` file name."""
    return int(re.search(r"checkpoint_(\d+)\.pth", fname).group(1))

# Only consider files that follow the checkpoint_<N>.pth scheme exactly; any
# other .pth file would make extract_epoch() fail on a None match.
checkpoint_files = [
    f for f in os.listdir(ckpt_dir) if re.fullmatch(r"checkpoint_\d+\.pth", f)
]

assert len(checkpoint_files) > 0, "Checkpoint tidak ditemukan!"

# Numeric sort (a string sort would place "checkpoint_100" before "checkpoint_20").
checkpoint_files.sort(key=extract_epoch)

latest_ckpt = os.path.join(ckpt_dir, checkpoint_files[-1])
print("Memuat checkpoint:", latest_ckpt)

checkpoint = torch.load(latest_ckpt, map_location=device)

# Restore both the weights and the optimizer state before continuing.
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])

start_epoch = checkpoint["epoch"] + 1
print("Lanjut dari epoch:", start_epoch)

# Number of ADDITIONAL epochs for this session. Using `epochs = 70` directly as
# the end of the range gave range(126, 71), which is empty, so nothing trained.
extra_epochs = 70
epochs = start_epoch - 1 + extra_epochs  # last epoch number (inclusive)

for ep in range(start_epoch, epochs + 1):
    loss, acc = train_one_epoch(model, train_dl)

    print(f"Epoch {ep:03d} | Loss {loss:.4f} | Acc {acc*100:.2f}%")

    # Save every 5 epochs and always at the end of the session.
    if ep % 5 == 0 or ep == epochs:
        torch.save({
            "epoch": ep,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict()
        }, os.path.join(ckpt_dir, f"checkpoint_{ep}.pth"))

        print("✅ Checkpoint saved")


Memuat checkpoint: /content/drive/MyDrive/Checkpoints_Rocks_FINAL/checkpoint_125.pth
Lanjut dari epoch: 126


In [None]:
# Load the newest checkpoint on disk for inference. Building the path from
# `epochs` is fragile: the resume cell reassigns `epochs` (it was set to 70
# there), which made this cell silently load intermediate epoch-70 weights.
saved_epochs = [
    int(m.group(1))
    for m in (re.fullmatch(r"checkpoint_(\d+)\.pth", f) for f in os.listdir(ckpt_dir))
    if m
]
assert saved_epochs, "Checkpoint tidak ditemukan!"

ckpt = torch.load(os.path.join(ckpt_dir, f"checkpoint_{max(saved_epochs)}.pth"),
                  map_location=device)

model.load_state_dict(ckpt["model"])
model.eval()  # dropout off, BatchNorm uses running statistics
print("Checkpoint loaded, siap prediksi")


Checkpoint loaded, siap prediksi


In [None]:
def extract_epoch(fname):
    """Return the epoch number N encoded in a ``checkpoint_<N>.pth`` file name."""
    return int(re.search(r"checkpoint_(\d+)\.pth", fname).group(1))

# Ignore .pth files that don't follow the checkpoint_<N>.pth scheme (they
# would make extract_epoch crash) and fail clearly if nothing is left.
ckpt_files = [
    f for f in os.listdir(ckpt_dir) if re.fullmatch(r"checkpoint_\d+\.pth", f)
]
assert ckpt_files, "Checkpoint tidak ditemukan!"
ckpt_files.sort(key=extract_epoch)

latest_ckpt = os.path.join(ckpt_dir, ckpt_files[-1])
print("Loading:", latest_ckpt)

ckpt = torch.load(latest_ckpt, map_location=device)
model.load_state_dict(ckpt["model"])
# Switch to inference mode: dropout off, BatchNorm uses running statistics.
model.eval()

print("✅ Checkpoint terakhir dimuat, siap prediksi")


Loading: /content/drive/MyDrive/Checkpoints_Rocks_FINAL/checkpoint_125.pth
✅ Checkpoint terakhir dimuat, siap prediksi


In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import os

def predict_image_with_photo(img_path, topk=5, threshold=0.70):
    """Classify one image, display it with the top prediction, and print the top-k classes.

    Uses the notebook-level ``model``, ``test_tfms``, ``device`` and ``class_name``.

    Args:
        img_path: path to an image file readable by PIL.
        topk: number of ranked classes to print; clamped to [1, number of classes].
        threshold: if the top probability is below this value, a
            low-confidence warning is printed.

    Returns:
        Name of the highest-probability class.

    Raises:
        AssertionError: if ``img_path`` does not exist.
    """
    assert os.path.exists(img_path), "File gambar tidak ditemukan"

    model.eval()
    # Open inside `with` so the file handle is closed; convert() returns a
    # fully loaded copy that remains valid after the file is closed.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")

    # Preprocess exactly like the test split and add a batch dimension.
    x = test_tfms(img).unsqueeze(0).to(device)

    with torch.no_grad():
        probs = F.softmax(model(x), dim=1)[0]

    # Tensor.topk raises if k exceeds the number of classes.
    k = max(1, min(int(topk), probs.numel()))
    top_vals, top_idx = probs.topk(k)
    top_vals = top_vals.tolist()
    top_idx = top_idx.tolist()

    best_name = class_name[top_idx[0]]
    best_prob = top_vals[0]

    plt.figure(figsize=(6, 6))
    plt.imshow(img)
    plt.axis("off")
    plt.title(
        f"PREDIKSI: {best_name}\n"
        f"CONFIDENCE: {best_prob*100:.2f}%"
    )
    plt.show()

    print("TOP PREDIKSI:")
    for i, v in zip(top_idx, top_vals):
        print(f"{class_name[i]:<25} {v*100:.2f}%")

    if best_prob < threshold:
        print("⚠️ Model kurang yakin")

    return best_name


In [None]:
# Example: classify a single test photo stored on Drive.
sample_img_path = "/content/drive/MyDrive/DEEPLEARN/testbatu3.jpg"
predict_image_with_photo(sample_img_path)


AssertionError: File gambar tidak ditemukan