In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split

In [2]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [3]:
dataset = datasets.ImageFolder(root='Rice_Image_Dataset', transform=transform)

In [4]:
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [5]:
googlenet = models.googlenet(pretrained=True)

num_features = googlenet.fc.in_features
num_classes = 5  
googlenet.fc = nn.Linear(num_features, num_classes)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(googlenet.parameters(), lr=0.001, momentum=0.9)



In [6]:
num_epochs = 1
for epoch in range(num_epochs):
    googlenet.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs, labels = inputs, labels

        optimizer.zero_grad()
        outputs = googlenet(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / len(train_loader)}')

Epoch 1/1, Loss: 0.08901187708787024


In [13]:
from torchviz import make_dot
dot = make_dot(outputs, params=dict(googlenet.named_parameters()))

dot.render("googlenet_rice_classifier", format="png")

'googlenet_rice_classifier.png'

In [7]:
googlenet.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs, labels
        outputs = googlenet(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
accuracy = correct / total
print(f'Test Accuracy: {accuracy * 100:.2f}%')

Test Accuracy: 99.81%


In [8]:
torch.save(googlenet.state_dict(), 'googlenet_rice_classifier.pth')