In [2]:
!cp -r /kaggle/input/garbage-classification/garbage_classification garbage_classification

In [3]:
!rm -r garbage_classification/clothes
!rm -r garbage_classification/shoes

In [12]:
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from torchvision.datasets import ImageFolder

In [13]:
# Preprocessing shared by training (through ImageFolder) and the inference cell:
# resize to a fixed 256x256 (aspect ratio is not preserved), collapse to a single
# channel, and convert to a tensor.
# NOTE(review): Grayscale discards colour. If dataset.classes contains classes that
# differ mainly by colour (e.g. brown/green/white glass -- confirm), this likely costs
# accuracy; switching to RGB also requires the model's first Conv2d to take 3 channels.
preprocess_steps = [
    transforms.Resize([256, 256]),
    transforms.Grayscale(),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocess_steps)
dataset = ImageFolder('garbage_classification', transform=transform)

In [14]:
from torch.utils.data import DataLoader

In [15]:
# Reshuffled mini-batches of 64 samples drawn from the whole dataset.
loader = DataLoader(dataset, batch_size=64, shuffle=True)

In [None]:
# Class names; an integer label i maps back to its name via dataset.classes[i].
class_names = dataset.classes
class_names

In [None]:
# Inspect the first sample: tensor shape, integer label, and the label's class name.
image, label = dataset[0]
print(image.shape)
print(label)
print(dataset.classes[label])

In [18]:
from torch import nn, optim

In [19]:
import torch

In [20]:
class Classification(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential( #1, 256, 256
            nn.Conv2d(1, 16, 4, 2, 1), #16, 128, 128
            nn.ReLU(),
            nn.Conv2d(16, 32, 4, 2, 1), #32, 64, 64
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1), #64, 32, 32
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1), #128, 16, 16
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, 2, 1), #256, 8, 8
            nn.ReLU(),
            nn.Flatten(), #256*8*8
            nn.Linear(256*8*8, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 10),
            nn.Softmax(1),
        )
    def forward(self, images):
        labels = self.net(images)
        return labels

In [21]:
# Instantiate the classifier on the current CUDA device.
C = Classification().to('cuda')

In [22]:
# Adam over all classifier parameters with a fixed learning rate of 2e-4.
optimC = optim.Adam(params=C.parameters(), lr=2e-4)

In [24]:
# Mean cross-entropy over the batch. The loss holds no tensors (weight is None),
# so there is nothing to move to the GPU.
loss_fn = nn.CrossEntropyLoss(reduction='mean')

In [25]:
from tqdm import tqdm

In [36]:
# trained_epoch: epochs completed before this run (offsets the checkpoint names);
# epochs: epochs to train in this run.
trained_epoch, epochs = 0, 400

In [None]:
# Train C for `epochs` passes over the dataset. After each epoch, report the mean
# per-sample training loss and the training accuracy.
C.train()  # make sure we are in training mode, e.g. after an eval()/inference cell
for epoch in tqdm(range(epochs)):
    total_lossC = 0.0
    correct = 0
    for source_images, labels in loader:
        source_images = source_images.cuda()
        # CrossEntropyLoss takes integer class indices directly. This gives the same
        # loss as a one-hot target, without hard-coding the class count.
        labels = labels.cuda()

        predictC = C(source_images)
        lossC = loss_fn(predictC, labels)

        optimC.zero_grad()
        lossC.backward()
        optimC.step()

        # .item() turns the values into Python numbers, so the running totals
        # don't keep every batch's autograd graph alive for the whole epoch.
        total_lossC += lossC.item() * source_images.size(0)
        correct += (predictC.argmax(1) == labels).sum().item()

    avg_lossC = total_lossC / len(dataset)
    accuracy = correct / len(dataset)
    # tqdm.write prints above the progress bar instead of breaking it.
    tqdm.write(f'C_loss: {avg_lossC:.6f} accuracy: {accuracy*100:.6f}%')
print(f'epoch {trained_epoch + epochs} complete')

In [39]:
# Persist model and optimizer state, tagged with the cumulative epoch count.
checkpoint_epoch = trained_epoch + epochs
model_state = C.state_dict()
optimizer_state = optimC.state_dict()
torch.save(model_state, f'C_{checkpoint_epoch}.pth')
torch.save(optimizer_state, f'optimC_{checkpoint_epoch}.pth')

In [47]:
import requests
from io import BytesIO

In [102]:
# Download one out-of-dataset image and save it as battery.jpg for a manual smoke test.
url = 'https://www.rei.com/dam/van_dragt_102517_0032_how_to_choose_batteries.jpg'
headers = {
    'User-Agent': 'Chrome/123.0.0.0'
}
# The timeout keeps the cell from hanging indefinitely on a stalled connection.
response = requests.get(url, headers=headers, timeout=30)
# Fail loudly on HTTP errors. Otherwise an HTML error page reaches PIL and
# shows up as a confusing "cannot identify image file" error.
response.raise_for_status()
# `with` closes the decoded image. The buffer is kept inline so the name `io`
# (the stdlib module) is not shadowed.
with Image.open(BytesIO(response.content)) as image:
    image.save('battery.jpg')

In [None]:
# Epoch tag of the checkpoint file (C_<epoch>.pth) to load for testing.
test_epoch: int = 400

In [None]:
# Rebuild the classifier on the GPU and restore the weights saved for `test_epoch`.
checkpoint_state = torch.load(f'C_{test_epoch}.pth')
C = Classification().cuda()
C.load_state_dict(checkpoint_state)

In [None]:
# Classify one image file with C and print the top class's softmax score and its name.
file = 'battery.jpg'
# Match ImageFolder's default loader, which converts every image to RGB before
# the transform runs. `with` closes the underlying file handle.
with Image.open(file) as img:
    # unsqueeze(0) adds the batch dimension -> (1, C, H, W) for whatever `transform` yields.
    image = transform(img.convert('RGB')).unsqueeze(0).cuda()

C.eval()
with torch.no_grad():
    labels = C(image)[0]  # per-class scores for the single image
    index = labels.argmax()
    print(labels[index].item())
    print(dataset.classes[index.item()])
C = C.train()  # restore training mode for any later training cells