In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
from PIL import Image
import random, time, torchprofile
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Read and Transform Data
class LoadData(Dataset):
    def __init__(self, file, root):
        self.root = root
        
        # Data tranformation
        self.transform = transforms.Compose([
            transforms.Resize(84),
            transforms.CenterCrop(84),
            transforms.RandomHorizontalFlip(1),
            transforms.RandomVerticalFlip(0.1),
            transforms.RandomPerspective(distortion_scale = 0.2, p = 0.2),
            transforms.RandomRotation(15),
            RandomChannel(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        
        # Read the data file
        with open(file, 'r') as files:
            self.data = files.readlines()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        line = self.data[idx].strip()
        img_path, label = line.split()
        img_path = os.path.join(self.root, img_path)
        img = Image.open(img_path).convert('RGB')
        label = int(label)
        img = self.transform(img)
        return img, label

class RandomChannel:
    def __init__(self):
        self.all_channel = [
            (0, 1, 2),  # RGB
            (0, 1),     # RG
            (0, 2),     # RB
            (1, 2),     # GB
            (0,),       # R
            (1,),       # G
            (2,)        # B
        ]

    def __call__(self, img):
        channels = list(img.split())
        random_channel = random.choice(self.all_channel)
        img_merged = Image.merge('RGB', [channels[i] if i in random_channel else Image.new('L', img.size) for i in range(3)])
        return img_merged

class DynamicConv2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, bias = True):
        super(DynamicConv2D, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias = bias

        # Define convoluion weight and bias
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.randn(out_channels))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        
        # Adjust weight based on input channels
        channels = x.size(1)
        if channels < self.in_channels:
            weight = self.weight[:, :channels, :, :]
        else:
            weight = self.weight
        return F.conv2d(x, weight, self.bias, self.stride, self.padding)

## Train ResNet18 with DynamicConv2D

In [2]:
# Create dataset instances
train_dataset = LoadData(file = 'train.txt', root = '')
val_dataset = LoadData(file = 'val.txt', root = '')
test_dataset = LoadData(file = 'test.txt', root = '')

# Create DataLoader instances
train_loader = DataLoader(train_dataset, batch_size = 64, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size = 64, shuffle = False)
test_loader = DataLoader(test_dataset, batch_size = 64, shuffle = False)

In [3]:
model = models.resnet18(pretrained = False)
model.conv1 = DynamicConv2D(in_channels = 3, out_channels = 64, kernel_size = 7, padding = 1)
model.fc = torch.nn.Linear(model.fc.in_features, 50)
model.to(device)



ResNet(
  (conv1): DynamicConv2D()
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)

In [4]:
epochs = 30
lr = 0.1
momentum = 0.9
weight_decay = 1e-4
optimizer = optim.SGD(model.parameters(), lr = lr, momentum = momentum, weight_decay = weight_decay)

def adjust_lr(optimizer, epoch):
    if epoch in [epochs * 0.5, epochs * 0.75, epochs * 0.85]:
        for p in optimizer.param_groups:
            p['lr'] *= 0.1
            lr = p['lr']
        print('Adjusted lr:' + str(lr))

def train(epoch):
    model.train()
    avg_loss = 0.
    train_acc = 0.
    adjust_lr(optimizer, epoch)
    train_loader_len = len(train_loader.dataset)
    train_loader_iter = tqdm(enumerate(train_loader), total = len(train_loader), desc = "Training Epoch #{}".format(epoch))
    
    for batch_idx, (data, label) in train_loader_iter:
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, label)
        avg_loss += loss.item()
        pred = output.data.max(1, keepdim = True)[1]
        train_acc += pred.eq(label.data.view_as(pred)).cpu().sum()
        loss.backward()
        optimizer.step()
        train_loader_iter.set_postfix(loss = loss.item(), accuracy = 100. * train_acc.item() / train_loader_len)    
    print('Train Epoch: {}, Loss: {:.6f}, Accuracy: {:.2f}%'.format(epoch, avg_loss / len(train_loader), 100. * train_acc / train_loader_len))

def val(epoch):
    model.eval()
    test_loss = 0.
    correct = 0
    val_loader_iter = tqdm(val_loader, total = len(val_loader), desc = "Validation Epoch #{}".format(epoch))
    
    with torch.no_grad():
        for data, label in val_loader_iter:
            data, label = data.to(device), label.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, label, reduction = 'sum').item()
            pred = output.data.max(1, keepdim = True)[1]
            correct += pred.eq(label.data.view_as(pred)).cpu().sum()
            val_loader_iter.set_postfix(loss = test_loss / len(val_loader.dataset), accuracy = 100. * correct.item() / len(val_loader.dataset))
    
    test_loss /= len(val_loader.dataset)
    accuracy = 100. * correct / len(val_loader.dataset)
    print('Validation Set: Average Loss: {:.4f}, Accuracy: {:.2f}%'.format(test_loss, accuracy))
    return accuracy

best_val_acc = 0.
for i in range(epochs):
    train(i + 1)
    temp_acc = val(i + 1)
    if temp_acc > best_val_acc:
        best_val_acc = temp_acc
        torch.save(model.state_dict(), 'best_1.pt')
        print('Best Accuracy: {:.2f}%'.format(best_val_acc))
print('Final Best Accuracy: {:.2f}%'.format(best_val_acc))

Training Epoch #1: 100%|██████████| 990/990 [03:12<00:00,  5.14it/s, accuracy=4.58, loss=3.55] 


Train Epoch: 1, Loss: 3.811792, Accuracy: 4.58%


Validation Epoch #1: 100%|██████████| 8/8 [00:01<00:00,  6.56it/s, accuracy=10, loss=3.58]  


Validation Set: Average Loss: 3.5834, Accuracy: 10.00%
Best Accuracy: 10.00%


Training Epoch #2: 100%|██████████| 990/990 [03:10<00:00,  5.21it/s, accuracy=10.7, loss=2.97] 


Train Epoch: 2, Loss: 3.398106, Accuracy: 10.72%


Validation Epoch #2: 100%|██████████| 8/8 [00:01<00:00,  6.68it/s, accuracy=10.2, loss=3.42] 


Validation Set: Average Loss: 3.4197, Accuracy: 10.22%
Best Accuracy: 10.22%


Training Epoch #3: 100%|██████████| 990/990 [03:10<00:00,  5.21it/s, accuracy=15.6, loss=3.32] 


Train Epoch: 3, Loss: 3.166864, Accuracy: 15.63%


Validation Epoch #3: 100%|██████████| 8/8 [00:01<00:00,  6.87it/s, accuracy=13.6, loss=3.19] 


Validation Set: Average Loss: 3.1871, Accuracy: 13.56%
Best Accuracy: 13.56%


Training Epoch #4: 100%|██████████| 990/990 [03:13<00:00,  5.12it/s, accuracy=20, loss=2.61]  


Train Epoch: 4, Loss: 2.961041, Accuracy: 19.99%


Validation Epoch #4: 100%|██████████| 8/8 [00:01<00:00,  6.78it/s, accuracy=19.8, loss=2.99] 


Validation Set: Average Loss: 2.9865, Accuracy: 19.78%
Best Accuracy: 19.78%


Training Epoch #5: 100%|██████████| 990/990 [03:11<00:00,  5.16it/s, accuracy=25.2, loss=2.51]


Train Epoch: 5, Loss: 2.728448, Accuracy: 25.15%


Validation Epoch #5: 100%|██████████| 8/8 [00:01<00:00,  6.99it/s, accuracy=25.3, loss=2.69] 


Validation Set: Average Loss: 2.6858, Accuracy: 25.33%
Best Accuracy: 25.33%


Training Epoch #6: 100%|██████████| 990/990 [03:11<00:00,  5.17it/s, accuracy=29.6, loss=2.17]


Train Epoch: 6, Loss: 2.531418, Accuracy: 29.60%


Validation Epoch #6: 100%|██████████| 8/8 [00:01<00:00,  6.74it/s, accuracy=26.7, loss=2.61] 


Validation Set: Average Loss: 2.6096, Accuracy: 26.67%
Best Accuracy: 26.67%


Training Epoch #7: 100%|██████████| 990/990 [03:10<00:00,  5.21it/s, accuracy=33.7, loss=2.62]


Train Epoch: 7, Loss: 2.357561, Accuracy: 33.67%


Validation Epoch #7: 100%|██████████| 8/8 [00:01<00:00,  6.89it/s, accuracy=26.9, loss=2.73] 


Validation Set: Average Loss: 2.7325, Accuracy: 26.89%
Best Accuracy: 26.89%


Training Epoch #8: 100%|██████████| 990/990 [03:10<00:00,  5.21it/s, accuracy=36.9, loss=2.32]


Train Epoch: 8, Loss: 2.220832, Accuracy: 36.87%


Validation Epoch #8: 100%|██████████| 8/8 [00:01<00:00,  6.68it/s, accuracy=30.7, loss=2.53] 


Validation Set: Average Loss: 2.5290, Accuracy: 30.67%
Best Accuracy: 30.67%


Training Epoch #9: 100%|██████████| 990/990 [03:08<00:00,  5.25it/s, accuracy=39.7, loss=2.88]


Train Epoch: 9, Loss: 2.111002, Accuracy: 39.73%


Validation Epoch #9: 100%|██████████| 8/8 [00:01<00:00,  6.89it/s, accuracy=31.6, loss=2.53] 


Validation Set: Average Loss: 2.5331, Accuracy: 31.56%
Best Accuracy: 31.56%


Training Epoch #10: 100%|██████████| 990/990 [03:06<00:00,  5.30it/s, accuracy=42, loss=1.88]  


Train Epoch: 10, Loss: 2.011834, Accuracy: 42.03%


Validation Epoch #10: 100%|██████████| 8/8 [00:01<00:00,  6.90it/s, accuracy=35.3, loss=2.39] 


Validation Set: Average Loss: 2.3908, Accuracy: 35.33%
Best Accuracy: 35.33%


Training Epoch #11: 100%|██████████| 990/990 [03:10<00:00,  5.20it/s, accuracy=44.3, loss=2.12]


Train Epoch: 11, Loss: 1.927272, Accuracy: 44.30%


Validation Epoch #11: 100%|██████████| 8/8 [00:01<00:00,  6.72it/s, accuracy=39.1, loss=2.25] 


Validation Set: Average Loss: 2.2468, Accuracy: 39.11%
Best Accuracy: 39.11%


Training Epoch #12: 100%|██████████| 990/990 [03:12<00:00,  5.14it/s, accuracy=46, loss=2.19]  


Train Epoch: 12, Loss: 1.852095, Accuracy: 46.03%


Validation Epoch #12: 100%|██████████| 8/8 [00:01<00:00,  6.73it/s, accuracy=38.2, loss=2.15] 


Validation Set: Average Loss: 2.1513, Accuracy: 38.22%


Training Epoch #13: 100%|██████████| 990/990 [03:10<00:00,  5.18it/s, accuracy=47.8, loss=1.53]


Train Epoch: 13, Loss: 1.786854, Accuracy: 47.76%


Validation Epoch #13: 100%|██████████| 8/8 [00:01<00:00,  6.69it/s, accuracy=38.7, loss=2.27] 


Validation Set: Average Loss: 2.2728, Accuracy: 38.67%


Training Epoch #14: 100%|██████████| 990/990 [03:08<00:00,  5.26it/s, accuracy=49.6, loss=1.83]


Train Epoch: 14, Loss: 1.724606, Accuracy: 49.56%


Validation Epoch #14: 100%|██████████| 8/8 [00:01<00:00,  6.88it/s, accuracy=35.1, loss=2.51] 


Validation Set: Average Loss: 2.5074, Accuracy: 35.11%
Adjusted lr:0.010000000000000002


Training Epoch #15: 100%|██████████| 990/990 [03:07<00:00,  5.28it/s, accuracy=61.6, loss=1]    


Train Epoch: 15, Loss: 1.299005, Accuracy: 61.56%


Validation Epoch #15: 100%|██████████| 8/8 [00:01<00:00,  6.71it/s, accuracy=48.9, loss=1.77] 


Validation Set: Average Loss: 1.7658, Accuracy: 48.89%
Best Accuracy: 48.89%


Training Epoch #16: 100%|██████████| 990/990 [03:08<00:00,  5.24it/s, accuracy=65.2, loss=0.86] 


Train Epoch: 16, Loss: 1.158642, Accuracy: 65.21%


Validation Epoch #16: 100%|██████████| 8/8 [00:01<00:00,  6.66it/s, accuracy=50.7, loss=1.69] 


Validation Set: Average Loss: 1.6905, Accuracy: 50.67%
Best Accuracy: 50.67%


Training Epoch #17: 100%|██████████| 990/990 [03:09<00:00,  5.22it/s, accuracy=67.1, loss=1.14] 


Train Epoch: 17, Loss: 1.087204, Accuracy: 67.14%


Validation Epoch #17: 100%|██████████| 8/8 [00:01<00:00,  6.78it/s, accuracy=55.1, loss=1.64] 


Validation Set: Average Loss: 1.6441, Accuracy: 55.11%
Best Accuracy: 55.11%


Training Epoch #18: 100%|██████████| 990/990 [03:12<00:00,  5.15it/s, accuracy=68.8, loss=1.03] 


Train Epoch: 18, Loss: 1.029867, Accuracy: 68.82%


Validation Epoch #18: 100%|██████████| 8/8 [00:01<00:00,  6.60it/s, accuracy=51.8, loss=1.75] 


Validation Set: Average Loss: 1.7514, Accuracy: 51.78%


Training Epoch #19: 100%|██████████| 990/990 [03:10<00:00,  5.21it/s, accuracy=70.4, loss=0.729]


Train Epoch: 19, Loss: 0.971468, Accuracy: 70.38%


Validation Epoch #19: 100%|██████████| 8/8 [00:01<00:00,  6.61it/s, accuracy=50.9, loss=1.74] 


Validation Set: Average Loss: 1.7382, Accuracy: 50.89%


Training Epoch #20: 100%|██████████| 990/990 [03:09<00:00,  5.23it/s, accuracy=72.2, loss=0.878]


Train Epoch: 20, Loss: 0.912340, Accuracy: 72.23%


Validation Epoch #20: 100%|██████████| 8/8 [00:01<00:00,  6.59it/s, accuracy=52.2, loss=1.73] 


Validation Set: Average Loss: 1.7251, Accuracy: 52.22%


Training Epoch #21: 100%|██████████| 990/990 [03:08<00:00,  5.25it/s, accuracy=73.5, loss=0.794]


Train Epoch: 21, Loss: 0.868522, Accuracy: 73.53%


Validation Epoch #21: 100%|██████████| 8/8 [00:01<00:00,  6.63it/s, accuracy=54.7, loss=1.72] 


Validation Set: Average Loss: 1.7200, Accuracy: 54.67%


Training Epoch #22: 100%|██████████| 990/990 [03:10<00:00,  5.19it/s, accuracy=75, loss=0.983]  


Train Epoch: 22, Loss: 0.820496, Accuracy: 75.02%


Validation Epoch #22: 100%|██████████| 8/8 [00:01<00:00,  6.79it/s, accuracy=49.3, loss=1.82] 


Validation Set: Average Loss: 1.8182, Accuracy: 49.33%


Training Epoch #23: 100%|██████████| 990/990 [03:12<00:00,  5.14it/s, accuracy=76.6, loss=0.727]


Train Epoch: 23, Loss: 0.774457, Accuracy: 76.59%


Validation Epoch #23: 100%|██████████| 8/8 [00:01<00:00,  6.69it/s, accuracy=52, loss=1.81]   


Validation Set: Average Loss: 1.8141, Accuracy: 52.00%


Training Epoch #24: 100%|██████████| 990/990 [03:11<00:00,  5.18it/s, accuracy=77.9, loss=0.615]


Train Epoch: 24, Loss: 0.731243, Accuracy: 77.88%


Validation Epoch #24: 100%|██████████| 8/8 [00:01<00:00,  6.65it/s, accuracy=51.3, loss=1.84] 


Validation Set: Average Loss: 1.8445, Accuracy: 51.33%


Training Epoch #25: 100%|██████████| 990/990 [03:05<00:00,  5.35it/s, accuracy=78.9, loss=0.659]


Train Epoch: 25, Loss: 0.698982, Accuracy: 78.86%


Validation Epoch #25: 100%|██████████| 8/8 [00:01<00:00,  6.58it/s, accuracy=51.3, loss=1.84] 


Validation Set: Average Loss: 1.8417, Accuracy: 51.33%


Training Epoch #26: 100%|██████████| 990/990 [03:09<00:00,  5.22it/s, accuracy=80.5, loss=0.637]


Train Epoch: 26, Loss: 0.660659, Accuracy: 80.47%


Validation Epoch #26: 100%|██████████| 8/8 [00:01<00:00,  6.77it/s, accuracy=51.3, loss=1.93] 


Validation Set: Average Loss: 1.9322, Accuracy: 51.33%


Training Epoch #27: 100%|██████████| 990/990 [03:08<00:00,  5.26it/s, accuracy=81.1, loss=0.642]


Train Epoch: 27, Loss: 0.632375, Accuracy: 81.10%


Validation Epoch #27: 100%|██████████| 8/8 [00:01<00:00,  6.59it/s, accuracy=51.6, loss=1.93] 


Validation Set: Average Loss: 1.9341, Accuracy: 51.56%


Training Epoch #28: 100%|██████████| 990/990 [03:10<00:00,  5.20it/s, accuracy=81.9, loss=0.672]


Train Epoch: 28, Loss: 0.599173, Accuracy: 81.93%


Validation Epoch #28: 100%|██████████| 8/8 [00:01<00:00,  6.72it/s, accuracy=49.6, loss=2.1]  


Validation Set: Average Loss: 2.0969, Accuracy: 49.56%


Training Epoch #29: 100%|██████████| 990/990 [03:10<00:00,  5.20it/s, accuracy=82.8, loss=0.402]


Train Epoch: 29, Loss: 0.579697, Accuracy: 82.82%


Validation Epoch #29: 100%|██████████| 8/8 [00:01<00:00,  6.78it/s, accuracy=50.4, loss=2.04] 


Validation Set: Average Loss: 2.0369, Accuracy: 50.44%


Training Epoch #30: 100%|██████████| 990/990 [03:10<00:00,  5.20it/s, accuracy=83.6, loss=0.484]


Train Epoch: 30, Loss: 0.555541, Accuracy: 83.61%


Validation Epoch #30: 100%|██████████| 8/8 [00:01<00:00,  6.93it/s, accuracy=48.2, loss=2.11] 

Validation Set: Average Loss: 2.1070, Accuracy: 48.22%
Final Best Accuracy: 55.11%





## Load ResNet18 with DynamicConv2D

In [5]:
model = models.resnet18(pretrained = False)
model.conv1 = DynamicConv2D(in_channels = 3, out_channels = 64, kernel_size = 7, padding = 1)
model.fc = torch.nn.Linear(model.fc.in_features, 50)
model.load_state_dict(torch.load("best_1.pt"))
model.to(device)
model.eval()

ResNet(
  (conv1): DynamicConv2D()
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)

## Evaluate by Testing Dataset

In [6]:
def test(model):
    model.eval()
    correct = 0
    test_dataset = LoadData(file = 'test.txt', root = '')
    test_loader = DataLoader(test_dataset, batch_size = 64, shuffle = False)
    test_loader_len = len(test_loader.dataset)
    test_loader_iter = tqdm(test_loader, total = len(test_loader), desc = "Testing")
    
    all_preds = []
    all_targets = []
    
    start_time = time.time()
    
    with torch.no_grad():
        for data, label in test_loader_iter:
            data, label = data.to(device), label.to(device)
            output = model(data)
            pred = output.data.max(1, keepdim = True)[1]
            correct += pred.eq(label.data.view_as(pred)).cpu().sum()
            
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(label.cpu().numpy())
            
            test_loader_iter.set_postfix(accuracy = 100. * correct.item() / test_loader_len)
    
    end_time = time.time()
    elapsed_time = end_time - start_time
    
    accuracy = 100. * correct / test_loader_len
    

    # Calculate precision, recall, and F1-score
    precision = 100. * precision_score(all_targets, all_preds, average = 'macro')
    recall = 100. * recall_score(all_targets, all_preds, average = 'macro')
    f1 = 100. * f1_score(all_targets, all_preds, average = 'macro')
    
    # Calculate FLOPS
    flops = torchprofile.profile_macs(model, torch.randn(1, *data.shape[1:]).to(device))
    
    return accuracy.item(), precision, recall, f1, flops, elapsed_time

rgb_list = ["RGB", "RG", "RB", "GB", "R", "G", "B"]
for rgb in rgb_list:
    resnet_acc, resnet_precision, resnet_recall, resnet_f1, resnet_flops, resnet_elapsed_time = test(model)
    print(f"RGB Set: {rgb:s}, Accuracy: {resnet_acc:.2f}%, Precision: {resnet_precision:.2f}%, Recall: {resnet_recall:.2f}%, F1 Score: {resnet_f1:.2f}%, FLOPS: {resnet_flops:d}, Elapsed Time: {resnet_elapsed_time:.2f} seconds")

Testing: 100%|██████████| 8/8 [00:01<00:00,  6.80it/s, accuracy=54.2]


RGB Set: RGB, Accuracy: 54.22%, Precision: 56.11%, Recall: 54.22%, F1 Score: 53.69%, FLOPS: 926579712, Elapsed Time: 1.18 seconds


Testing: 100%|██████████| 8/8 [00:01<00:00,  6.80it/s, accuracy=53.6]


RGB Set: RG, Accuracy: 53.56%, Precision: 53.97%, Recall: 53.56%, F1 Score: 52.56%, FLOPS: 926579712, Elapsed Time: 1.18 seconds


Testing: 100%|██████████| 8/8 [00:01<00:00,  6.80it/s, accuracy=52]  


RGB Set: RB, Accuracy: 52.00%, Precision: 52.44%, Recall: 52.00%, F1 Score: 50.86%, FLOPS: 926579712, Elapsed Time: 1.18 seconds


Testing: 100%|██████████| 8/8 [00:01<00:00,  6.70it/s, accuracy=54.2]


RGB Set: GB, Accuracy: 54.22%, Precision: 56.55%, Recall: 54.22%, F1 Score: 54.08%, FLOPS: 926579712, Elapsed Time: 1.19 seconds


Testing: 100%|██████████| 8/8 [00:01<00:00,  6.78it/s, accuracy=50]  


RGB Set: R, Accuracy: 50.00%, Precision: 51.50%, Recall: 50.00%, F1 Score: 48.99%, FLOPS: 926579712, Elapsed Time: 1.18 seconds


Testing: 100%|██████████| 8/8 [00:01<00:00,  6.80it/s, accuracy=52.7]


RGB Set: G, Accuracy: 52.67%, Precision: 54.41%, Recall: 52.67%, F1 Score: 51.93%, FLOPS: 926579712, Elapsed Time: 1.18 seconds


Testing: 100%|██████████| 8/8 [00:01<00:00,  6.65it/s, accuracy=52]  


RGB Set: B, Accuracy: 52.00%, Precision: 54.26%, Recall: 52.00%, F1 Score: 51.44%, FLOPS: 926579712, Elapsed Time: 1.20 seconds


## Train ResNet18

In [7]:
model = models.resnet18(pretrained = False)
model.fc = torch.nn.Linear(model.fc.in_features, 50)
model.to(device)



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [8]:
epochs = 30
lr = 0.1
momentum = 0.9
weight_decay = 1e-4
optimizer = optim.SGD(model.parameters(), lr = lr, momentum = momentum, weight_decay = weight_decay)

def adjust_lr(optimizer, epoch):
    if epoch in [epochs * 0.5, epochs * 0.75, epochs * 0.85]:
        for p in optimizer.param_groups:
            p['lr'] *= 0.1
            lr = p['lr']
        print('Adjusted lr:' + str(lr))

def train(epoch):
    model.train()
    avg_loss = 0.
    train_acc = 0.
    adjust_lr(optimizer, epoch)
    train_loader_len = len(train_loader.dataset)
    train_loader_iter = tqdm(enumerate(train_loader), total = len(train_loader), desc = "Training Epoch #{}".format(epoch))
    
    for batch_idx, (data, label) in train_loader_iter:
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, label)
        avg_loss += loss.item()
        pred = output.data.max(1, keepdim = True)[1]
        train_acc += pred.eq(label.data.view_as(pred)).cpu().sum()
        loss.backward()
        optimizer.step()
        train_loader_iter.set_postfix(loss = loss.item(), accuracy = 100. * train_acc.item() / train_loader_len)
    print('Train Epoch: {}, Loss: {:.6f}, Accuracy: {:.2f}%'.format(epoch, avg_loss / len(train_loader), 100. * train_acc / train_loader_len))

def val(epoch):
    model.eval()
    test_loss = 0.
    correct = 0
    val_loader_iter = tqdm(val_loader, total = len(val_loader), desc = "Validation Epoch #{}".format(epoch))
    
    with torch.no_grad():
        for data, label in val_loader_iter:
            data, label = data.to(device), label.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, label, reduction = 'sum').item()
            pred = output.data.max(1, keepdim = True)[1]
            correct += pred.eq(label.data.view_as(pred)).cpu().sum()
            val_loader_iter.set_postfix(loss = test_loss / len(val_loader.dataset), accuracy = 100. * correct.item() / len(val_loader.dataset))
    
    test_loss /= len(val_loader.dataset)
    accuracy = 100. * correct / len(val_loader.dataset)
    print('Validation Set: Average Loss: {:.4f}, Accuracy: {:.2f}%'.format(test_loss, accuracy))
    return accuracy
    
best_val_acc = 0.
for i in range(epochs):
    train(i + 1)
    temp_acc = val(i + 1)
    if temp_acc > best_val_acc:
        best_val_acc = temp_acc
        torch.save(model.state_dict(), 'resnet18_best.pt')
        print('Best Accuracy: {:.2f}%'.format(best_val_acc))
print('Final Best Accuracy: {:.2f}%'.format(best_val_acc))

Training Epoch #1: 100%|██████████| 990/990 [02:51<00:00,  5.77it/s, accuracy=4.09, loss=3.71] 


Train Epoch: 1, Loss: 3.841783, Accuracy: 4.09%


Validation Epoch #1: 100%|██████████| 8/8 [00:01<00:00,  6.75it/s, accuracy=6, loss=3.64]    


Validation Set: Average Loss: 3.6436, Accuracy: 6.00%
Best Accuracy: 6.00%


Training Epoch #2: 100%|██████████| 990/990 [02:53<00:00,  5.71it/s, accuracy=7.14, loss=3.55] 


Train Epoch: 2, Loss: 3.622614, Accuracy: 7.14%


Validation Epoch #2: 100%|██████████| 8/8 [00:01<00:00,  7.06it/s, accuracy=10.4, loss=3.54] 


Validation Set: Average Loss: 3.5381, Accuracy: 10.44%
Best Accuracy: 10.44%


Training Epoch #3: 100%|██████████| 990/990 [02:53<00:00,  5.71it/s, accuracy=10.3, loss=3.85] 


Train Epoch: 3, Loss: 3.464272, Accuracy: 10.29%


Validation Epoch #3: 100%|██████████| 8/8 [00:01<00:00,  6.85it/s, accuracy=10.4, loss=3.44]  


Validation Set: Average Loss: 3.4352, Accuracy: 10.44%


Training Epoch #4: 100%|██████████| 990/990 [02:53<00:00,  5.70it/s, accuracy=13.5, loss=3.3]  


Train Epoch: 4, Loss: 3.298670, Accuracy: 13.49%


Validation Epoch #4: 100%|██████████| 8/8 [00:01<00:00,  6.94it/s, accuracy=13.8, loss=3.42] 


Validation Set: Average Loss: 3.4183, Accuracy: 13.78%
Best Accuracy: 13.78%


Training Epoch #5: 100%|██████████| 990/990 [02:53<00:00,  5.71it/s, accuracy=16.2, loss=2.98]


Train Epoch: 5, Loss: 3.164157, Accuracy: 16.16%


Validation Epoch #5: 100%|██████████| 8/8 [00:01<00:00,  7.00it/s, accuracy=18.7, loss=3.24] 


Validation Set: Average Loss: 3.2370, Accuracy: 18.67%
Best Accuracy: 18.67%


Training Epoch #6: 100%|██████████| 990/990 [02:53<00:00,  5.72it/s, accuracy=19, loss=3.02]  


Train Epoch: 6, Loss: 3.027364, Accuracy: 19.02%


Validation Epoch #6: 100%|██████████| 8/8 [00:01<00:00,  7.05it/s, accuracy=15.3, loss=3.38] 


Validation Set: Average Loss: 3.3764, Accuracy: 15.33%


Training Epoch #7: 100%|██████████| 990/990 [02:53<00:00,  5.70it/s, accuracy=21.7, loss=3.09]


Train Epoch: 7, Loss: 2.901107, Accuracy: 21.66%


Validation Epoch #7: 100%|██████████| 8/8 [00:01<00:00,  7.17it/s, accuracy=18.4, loss=3.12] 


Validation Set: Average Loss: 3.1183, Accuracy: 18.44%


Training Epoch #8: 100%|██████████| 990/990 [02:53<00:00,  5.72it/s, accuracy=24, loss=3.34]  


Train Epoch: 8, Loss: 2.797129, Accuracy: 23.98%


Validation Epoch #8: 100%|██████████| 8/8 [00:01<00:00,  7.04it/s, accuracy=20.7, loss=2.98]


Validation Set: Average Loss: 2.9762, Accuracy: 20.67%
Best Accuracy: 20.67%


Training Epoch #9: 100%|██████████| 990/990 [02:53<00:00,  5.70it/s, accuracy=26.5, loss=2.81]


Train Epoch: 9, Loss: 2.691924, Accuracy: 26.54%


Validation Epoch #9: 100%|██████████| 8/8 [00:01<00:00,  6.99it/s, accuracy=18, loss=3.27]   


Validation Set: Average Loss: 3.2705, Accuracy: 18.00%


Training Epoch #10: 100%|██████████| 990/990 [02:53<00:00,  5.71it/s, accuracy=28.5, loss=2.69]


Train Epoch: 10, Loss: 2.605707, Accuracy: 28.53%


Validation Epoch #10: 100%|██████████| 8/8 [00:01<00:00,  7.09it/s, accuracy=24.4, loss=2.79] 


Validation Set: Average Loss: 2.7882, Accuracy: 24.44%
Best Accuracy: 24.44%


Training Epoch #11: 100%|██████████| 990/990 [02:53<00:00,  5.70it/s, accuracy=30.5, loss=3]   


Train Epoch: 11, Loss: 2.513652, Accuracy: 30.50%


Validation Epoch #11: 100%|██████████| 8/8 [00:01<00:00,  6.96it/s, accuracy=22.9, loss=3.06] 


Validation Set: Average Loss: 3.0596, Accuracy: 22.89%


Training Epoch #12: 100%|██████████| 990/990 [02:53<00:00,  5.70it/s, accuracy=32.3, loss=2.57]


Train Epoch: 12, Loss: 2.436449, Accuracy: 32.30%


Validation Epoch #12: 100%|██████████| 8/8 [00:01<00:00,  7.03it/s, accuracy=26.2, loss=2.81] 


Validation Set: Average Loss: 2.8121, Accuracy: 26.22%
Best Accuracy: 26.22%


Training Epoch #13: 100%|██████████| 990/990 [02:54<00:00,  5.69it/s, accuracy=34.5, loss=2.55]


Train Epoch: 13, Loss: 2.350140, Accuracy: 34.46%


Validation Epoch #13: 100%|██████████| 8/8 [00:01<00:00,  6.92it/s, accuracy=24.7, loss=2.84] 


Validation Set: Average Loss: 2.8418, Accuracy: 24.67%


Training Epoch #14: 100%|██████████| 990/990 [02:53<00:00,  5.69it/s, accuracy=36.1, loss=2.25]


Train Epoch: 14, Loss: 2.274475, Accuracy: 36.15%


Validation Epoch #14: 100%|██████████| 8/8 [00:01<00:00,  7.01it/s, accuracy=26.4, loss=2.8]  


Validation Set: Average Loss: 2.8005, Accuracy: 26.44%
Best Accuracy: 26.44%
Adjusted lr:0.010000000000000002


Training Epoch #15: 100%|██████████| 990/990 [02:53<00:00,  5.71it/s, accuracy=47.2, loss=1.47]


Train Epoch: 15, Loss: 1.852447, Accuracy: 47.15%


Validation Epoch #15: 100%|██████████| 8/8 [00:01<00:00,  7.07it/s, accuracy=36, loss=2.27]   


Validation Set: Average Loss: 2.2663, Accuracy: 36.00%
Best Accuracy: 36.00%


Training Epoch #16: 100%|██████████| 990/990 [02:53<00:00,  5.70it/s, accuracy=51.3, loss=2.37]


Train Epoch: 16, Loss: 1.695119, Accuracy: 51.30%


Validation Epoch #16: 100%|██████████| 8/8 [00:01<00:00,  6.92it/s, accuracy=35.6, loss=2.38] 


Validation Set: Average Loss: 2.3775, Accuracy: 35.56%


Training Epoch #17: 100%|██████████| 990/990 [02:53<00:00,  5.72it/s, accuracy=53.3, loss=2.21]


Train Epoch: 17, Loss: 1.609303, Accuracy: 53.31%


Validation Epoch #17: 100%|██████████| 8/8 [00:01<00:00,  7.00it/s, accuracy=34.4, loss=2.39] 


Validation Set: Average Loss: 2.3894, Accuracy: 34.44%


Training Epoch #18: 100%|██████████| 990/990 [02:53<00:00,  5.71it/s, accuracy=55.6, loss=1.85] 


Train Epoch: 18, Loss: 1.530552, Accuracy: 55.63%


Validation Epoch #18: 100%|██████████| 8/8 [00:01<00:00,  6.85it/s, accuracy=37.1, loss=2.34] 


Validation Set: Average Loss: 2.3381, Accuracy: 37.11%
Best Accuracy: 37.11%


Training Epoch #19: 100%|██████████| 990/990 [02:54<00:00,  5.69it/s, accuracy=57.7, loss=1.27] 


Train Epoch: 19, Loss: 1.445913, Accuracy: 57.74%


Validation Epoch #19: 100%|██████████| 8/8 [00:01<00:00,  6.99it/s, accuracy=36.4, loss=2.39] 


Validation Set: Average Loss: 2.3884, Accuracy: 36.44%


Training Epoch #20: 100%|██████████| 990/990 [02:53<00:00,  5.70it/s, accuracy=59.7, loss=1.69] 


Train Epoch: 20, Loss: 1.382534, Accuracy: 59.67%


Validation Epoch #20: 100%|██████████| 8/8 [00:01<00:00,  6.82it/s, accuracy=36, loss=2.45]   


Validation Set: Average Loss: 2.4470, Accuracy: 36.00%


Training Epoch #21: 100%|██████████| 990/990 [02:53<00:00,  5.70it/s, accuracy=61.7, loss=1.81] 


Train Epoch: 21, Loss: 1.314273, Accuracy: 61.72%


Validation Epoch #21: 100%|██████████| 8/8 [00:01<00:00,  6.97it/s, accuracy=34, loss=2.51]   


Validation Set: Average Loss: 2.5071, Accuracy: 34.00%


Training Epoch #22: 100%|██████████| 990/990 [02:53<00:00,  5.70it/s, accuracy=63.8, loss=1.16] 


Train Epoch: 22, Loss: 1.240137, Accuracy: 63.84%


Validation Epoch #22: 100%|██████████| 8/8 [00:01<00:00,  7.02it/s, accuracy=35.3, loss=2.48] 


Validation Set: Average Loss: 2.4760, Accuracy: 35.33%


Training Epoch #23: 100%|██████████| 990/990 [02:53<00:00,  5.70it/s, accuracy=66.2, loss=1.46] 


Train Epoch: 23, Loss: 1.165202, Accuracy: 66.16%


Validation Epoch #23: 100%|██████████| 8/8 [00:01<00:00,  6.90it/s, accuracy=34.7, loss=2.49] 


Validation Set: Average Loss: 2.4865, Accuracy: 34.67%


Training Epoch #24: 100%|██████████| 990/990 [02:53<00:00,  5.71it/s, accuracy=67.8, loss=1.45] 


Train Epoch: 24, Loss: 1.110179, Accuracy: 67.78%


Validation Epoch #24: 100%|██████████| 8/8 [00:01<00:00,  6.81it/s, accuracy=35.8, loss=2.53] 


Validation Set: Average Loss: 2.5263, Accuracy: 35.78%


Training Epoch #25: 100%|██████████| 990/990 [02:53<00:00,  5.70it/s, accuracy=69.3, loss=1.15] 


Train Epoch: 25, Loss: 1.056775, Accuracy: 69.28%


Validation Epoch #25: 100%|██████████| 8/8 [00:01<00:00,  6.94it/s, accuracy=32.2, loss=2.68] 


Validation Set: Average Loss: 2.6752, Accuracy: 32.22%


Training Epoch #26: 100%|██████████| 990/990 [02:53<00:00,  5.71it/s, accuracy=71.3, loss=1.27] 


Train Epoch: 26, Loss: 0.994045, Accuracy: 71.34%


Validation Epoch #26: 100%|██████████| 8/8 [00:01<00:00,  7.01it/s, accuracy=33.1, loss=2.82] 


Validation Set: Average Loss: 2.8245, Accuracy: 33.11%


Training Epoch #27: 100%|██████████| 990/990 [02:53<00:00,  5.70it/s, accuracy=73, loss=0.893]  


Train Epoch: 27, Loss: 0.943668, Accuracy: 73.01%


Validation Epoch #27: 100%|██████████| 8/8 [00:01<00:00,  6.97it/s, accuracy=31.3, loss=2.77] 


Validation Set: Average Loss: 2.7693, Accuracy: 31.33%


Training Epoch #28: 100%|██████████| 990/990 [02:53<00:00,  5.71it/s, accuracy=74.5, loss=0.922]


Train Epoch: 28, Loss: 0.894272, Accuracy: 74.53%


Validation Epoch #28: 100%|██████████| 8/8 [00:01<00:00,  6.98it/s, accuracy=32.7, loss=2.75] 


Validation Set: Average Loss: 2.7545, Accuracy: 32.67%


Training Epoch #29: 100%|██████████| 990/990 [02:53<00:00,  5.70it/s, accuracy=75.4, loss=0.812]


Train Epoch: 29, Loss: 0.862565, Accuracy: 75.40%


Validation Epoch #29: 100%|██████████| 8/8 [00:01<00:00,  6.97it/s, accuracy=33.3, loss=2.82] 


Validation Set: Average Loss: 2.8161, Accuracy: 33.33%


Training Epoch #30: 100%|██████████| 990/990 [02:53<00:00,  5.69it/s, accuracy=76.7, loss=0.884]


Train Epoch: 30, Loss: 0.822482, Accuracy: 76.70%


Validation Epoch #30: 100%|██████████| 8/8 [00:01<00:00,  7.07it/s, accuracy=32.2, loss=3.03] 

Validation Set: Average Loss: 3.0281, Accuracy: 32.22%
Final Best Accuracy: 37.11%





## Load ResNet18

In [9]:
resnet18_model = models.resnet18(pretrained = False)
resnet18_model.fc = torch.nn.Linear(model.fc.in_features, 50)
resnet18_model.load_state_dict(torch.load("resnet18_best.pt"))
resnet18_model.to(device)
resnet18_model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

## Evaluate by testing dataset

In [10]:
test_dataset = LoadData(file = 'test.txt', root = '')
test_loader = DataLoader(test_dataset, batch_size = 64, shuffle = False)
def test(model):
    model.eval()
    correct = 0
    test_loader_len = len(test_loader.dataset)
    test_loader_iter = tqdm(test_loader, total = len(test_loader), desc = "Testing")
    all_preds = []
    all_targets = []
    start_time = time.time()
    
    with torch.no_grad():
        for data, label in test_loader_iter:
            data, label = data.to(device), label.to(device)
            output = model(data)
            pred = output.data.max(1, keepdim = True)[1]
            correct += pred.eq(label.data.view_as(pred)).cpu().sum()
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(label.cpu().numpy())
            test_loader_iter.set_postfix(accuracy = 100. * correct.item() / test_loader_len)
    
    end_time = time.time()
    elapsed_time = end_time - start_time
    accuracy = 100. * correct / test_loader_len

    # Calculate precision, recall, and F1-score
    precision = 100. * precision_score(all_targets, all_preds, average = 'macro')
    recall = 100. * recall_score(all_targets, all_preds, average = 'macro')
    f1 = 100. * f1_score(all_targets, all_preds, average = 'macro')
    
    # Calculate FLOPS
    flops = torchprofile.profile_macs(model, torch.randn(1, *data.shape[1:]).to(device))
    return accuracy.item(), precision, recall, f1, flops, elapsed_time

resnet_acc, resnet_precision, resnet_recall, resnet_f1, resnet_flops, resnet_elapsed_time = test(resnet18_model)
print(f"Accuracy: {resnet_acc:.2f}%, Precision: {resnet_precision:.2f}%, Recall: {resnet_recall:.2f}%, F1 Score: {resnet_f1:.2f}%, FLOPS: {resnet_flops:d}, Elapsed Time: {resnet_elapsed_time:.2f} seconds")

Testing: 100%|██████████| 8/8 [00:01<00:00,  7.05it/s, accuracy=38.4]


Accuracy: 38.44%, Precision: 38.91%, Recall: 38.44%, F1 Score: 37.66%, FLOPS: 296456064, Elapsed Time: 1.14 seconds
