In [1]:
from PIL import Image
import torch
import torch.nn as nn

In [7]:
class block(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip path.

    The main path narrows to ``intermediate_channels``, applies a 3x3 conv with
    the given ``stride``, then widens to ``intermediate_channels * 4`` channels.
    The result is added to the (optionally projected) input and passed through
    a ReLU.

    Args:
        in_channels: Channels of the input tensor.
        intermediate_channels: Width of the inner 1x1/3x3 convs; the block
            outputs ``intermediate_channels * 4`` channels.
        identity_downsample: Optional module applied to the input on the skip
            path. It is needed whenever ``stride != 1`` or ``in_channels`` differs
            from ``intermediate_channels * 4``; otherwise the residual addition
            fails on mismatched shapes.
        stride: Stride of the 3x3 convolution.
    """

    def __init__(self, in_channels, intermediate_channels, identity_downsample=None, stride=1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, intermediate_channels, 1, 1, bias=False),
            nn.BatchNorm2d(intermediate_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(intermediate_channels, intermediate_channels, 3, stride, 1, bias=False),
            nn.BatchNorm2d(intermediate_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(intermediate_channels, intermediate_channels * 4, 1, 1, bias=False),
            nn.BatchNorm2d(intermediate_channels * 4)
        )
        self.identity_downsample = identity_downsample
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(conv(x) + skip(x))``."""
        # The skip path can reference ``x`` directly: nothing below modifies it
        # in place (each in-place ReLU acts on a fresh conv/BN output, and the
        # final one on the new tensor produced by the addition), so no copy of
        # the input is needed.
        if self.identity_downsample is None:
            identity = x
        else:
            identity = self.identity_downsample(x)

        out = self.conv(x)
        return self.relu(out + identity)

In [8]:
class ResNet50(nn.Module):
    """ResNet-style image classifier built from four stages of residual blocks.

    Args:
        block: Residual block class. It is called as
            ``block(in_channels, intermediate_channels, identity_downsample, stride)``
            for the first block of a stage and ``block(in_channels,
            intermediate_channels)`` for the rest. It must output
            ``intermediate_channels * expansion`` channels.
        in_channels: Channels of the input images (e.g. 3 for RGB).
        num_classes: Length of the output logit vector.
        features: Number of blocks in each of the four stages. The default
            ``(3, 4, 6, 3)`` is the ResNet-50 layout.
    """

    # Factor by which a block widens its ``intermediate_channels``.
    expansion = 4

    def __init__(self, block, in_channels, num_classes, features=(3, 4, 6, 3)):
        super().__init__()
        # Running channel count; _make_conv_layers reads and advances it.
        self.in_channels = 64
        self.num_classes = num_classes

        # Stem: 7x7 stride-2 conv + BN + ReLU + 3x3 stride-2 max-pool
        # (4x spatial reduction overall).
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1)
        )

        # Four residual stages; each stage after the first uses stride 2.
        self.layers2 = self._make_conv_layers(block, 64, features[0], 1)
        self.layers3 = self._make_conv_layers(block, 128, features[1], 2)
        self.layers4 = self._make_conv_layers(block, 256, features[2], 2)
        self.layers5 = self._make_conv_layers(block, 512, features[3], 2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # After the last stage self.in_channels equals 512 * expansion (2048);
        # derive the classifier width from it instead of hard-coding the number.
        self.fc = nn.Linear(self.in_channels, num_classes)

    def forward(self, x):
        """Map an image batch to logits of shape (batch, num_classes)."""
        x = self.conv(x)
        x = self.layers2(x)
        x = self.layers3(x)
        x = self.layers4(x)
        x = self.layers5(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

    def _make_conv_layers(self, block, intermediate_channels, num_repeats, stride):
        """Build one stage of ``num_repeats`` blocks and advance ``self.in_channels``.

        Only the first block of a stage can change resolution or channel count,
        so only it receives a 1x1 projection on its skip path.
        """
        out_channels = intermediate_channels * self.expansion
        downsample = None

        # Any stride other than 1 (not only 2) changes the spatial size.
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

        layers = [block(self.in_channels, intermediate_channels, downsample, stride)]
        self.in_channels = out_channels
        layers.extend(block(self.in_channels, intermediate_channels) for _ in range(num_repeats - 1))

        return nn.Sequential(*layers)

In [10]:
# Sanity check: a batch of 32 RGB 224x224 images should map to one logit per class.
x = torch.randn(32, 3, 224, 224)
model = ResNet50(block, 3, 1000)
# A shape check needs no gradients; disabling autograd avoids keeping every
# intermediate activation alive for a backward pass that never happens.
with torch.no_grad():
    out = model(x)
assert out.shape == (32, 1000), out.shape
print(out.shape)

torch.Size([32, 1000])


In [12]:
import os
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import numpy as np 
from collections import defaultdict

In [13]:
class PlantDataset(Dataset):
    """Multi-label image dataset described by a CSV file.

    The CSV's first column is the image file name (relative to ``img_dir``).
    Its ``labels`` column holds whitespace-separated category names, so one
    image can carry several categories. Each distinct category becomes one
    column of a 0/1 target matrix.

    Args:
        img_dir: Directory the file names in the CSV are relative to.
        csv_file: Path to the annotations CSV.
        transform: Optional callable applied to each loaded RGB image.

    Attributes:
        annotations: The CSV as read by ``pd.read_csv``.
        classes: Sorted category names; this is the order of the target columns.
        new_df: ``int8`` DataFrame of shape (len(annotations), len(classes)) with
            a 1 wherever a row's ``labels`` contains that category.
    """

    def __init__(self, img_dir, csv_file, transform=None):
        super().__init__()
        # Plain attribute name: ``__annotations__`` is a special attribute Python
        # reserves for type annotations.
        self.annotations = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

        # Map each category to the row positions whose labels contain it.
        rows_per_class = defaultdict(list)
        for row, label in enumerate(self.annotations["labels"]):
            for category in label.split():
                rows_per_class[category].append(row)

        # Sorted so the column order depends only on the set of categories,
        # not on the order in which rows appear in the CSV.
        self.classes = sorted(rows_per_class)
        targets = np.zeros((len(self.annotations), len(self.classes)), dtype=np.int8)
        for col, category in enumerate(self.classes):
            targets[rows_per_class[category], col] = 1

        self.new_df = pd.DataFrame(targets, columns=self.classes)
        # float32 copy so __getitem__ can hand rows straight to the loss.
        self._targets = targets.astype(np.float32)

    def __len__(self):
        """Number of annotated images (rows of the CSV)."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Return ``(image, target)`` for row ``index``.

        ``target`` is a float32 tensor of length ``len(self.classes)``.
        """
        img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
        # Context manager closes the underlying file once the RGB copy exists.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.transform:
            img = self.transform(img)

        label = torch.tensor(self._targets[index], dtype=torch.float32)

        return img, label

In [14]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Optimisation hyper-parameters.
learning_rate = 0.001
batch_size = 32

# Seed PyTorch's global RNG so later random operations (weight initialisation,
# DataLoader shuffling) repeat across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [20]:
from sklearn.model_selection import train_test_split
import torchvision.transforms as transforms

In [21]:
# Preprocessing applied to every image: resize to a fixed 224x224, then
# convert to a tensor with ToTensor.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

In [None]:
# NOTE(review): CSV_FILE is a placeholder — point it at the real annotations CSV.
IMG_DIR = 'images'
CSV_FILE = 'train.csv'
TEST_SIZE = 0.2
SPLIT_SEED = 42

# Both splits are disjoint index subsets of ONE PlantDataset, so they share the
# same label columns and the test set never contains training images.
full_dataset = PlantDataset(img_dir=IMG_DIR, csv_file=CSV_FILE, transform=transform)
train_idx, test_idx = train_test_split(
    np.arange(len(full_dataset)), test_size=TEST_SIZE, random_state=SPLIT_SEED
)
train_dataset = torch.utils.data.Subset(full_dataset, train_idx)
test_dataset = torch.utils.data.Subset(full_dataset, test_idx)

train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size, shuffle=False)

In [16]:
# BCEWithLogitsLoss needs logits shaped exactly like the targets, so the output
# width must equal the target-vector length. Read it from one sample (this loads
# one image) instead of hard-coding 1000.
num_classes = train_dataset[0][1].numel()
model = ResNet50(block, 3, num_classes).to(device)
# Multi-label targets: an independent sigmoid + binary cross-entropy per class.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), learning_rate)

In [17]:
from tqdm.auto import tqdm

In [18]:
# Number of training epochs; the training loop runs one pass over train_loader
# (and one over test_loader) per epoch.
num_epochs = 100

In [None]:
def run_epoch(model, loader, loss_fn, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, exact_match_acc)``.

    If ``optimizer`` is given, the model is put in training mode and updated
    after every batch. Otherwise it is put in eval mode and the pass runs
    without autograd.

    Accuracy is exact-match (subset) accuracy for multi-label targets: a sample
    counts as correct only when every thresholded class prediction equals its
    target.

    Args:
        model: Network mapping an image batch to logits shaped like the targets.
        loader: Sized iterable of ``(images, targets)`` batches.
        loss_fn: Loss called as ``loss_fn(logits, targets)``.
        device: Device each batch is moved to.
        optimizer: Optimizer for a training pass, or ``None`` for evaluation.

    Returns:
        ``(loss averaged over batches, fraction of samples exactly right)``;
        both are 0.0 for an empty loader.
    """
    training = optimizer is not None
    # Set the mode explicitly on every pass: BatchNorm layers behave differently
    # in train() and eval(), and evaluation must not leave training in eval mode.
    model.train(training)

    total_loss = 0.0
    num_correct = 0
    total_sample = 0

    with torch.set_grad_enabled(training):
        for img, label in tqdm(loader, leave=False):
            img, label = img.to(device), label.to(device)

            # No squeeze(): it would drop the batch dimension for a batch of
            # size 1 and break the shape match with ``label``.
            logits = model(img)
            loss = loss_fn(logits, label)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()

            # Round each class probability independently to 0/1.
            predictions = torch.round(torch.sigmoid(logits.detach()))
            # all(dim=1): one exact-match flag per sample, not one per batch.
            num_correct += (predictions == label).all(dim=1).sum().item()
            total_sample += label.shape[0]

    mean_loss = total_loss / len(loader) if len(loader) else 0.0
    accuracy = num_correct / total_sample if total_sample else 0.0
    return mean_loss, accuracy


for epoch in tqdm(range(num_epochs)):
    train_loss, train_acc = run_epoch(model, train_loader, loss_fn, device, optimizer)
    print(f"epoch: {epoch + 1}/{num_epochs} \n train_loss: {train_loss} train_acc: {train_acc} \n")

    test_loss, test_acc = run_epoch(model, test_loader, loss_fn, device)
    print(f"test_loss: {test_loss} test_acc: {test_acc}")