In [1]:
# %% [markdown]
# # 改進後的動態 ResNet18 模型訓練與多通道推論
# 此 Notebook 實現：
# 1. 使用同資料夾中的 images 與 txt 文件（train.txt、val.txt、test.txt）建立自訂 Dataset  
# 2. 定義改進後的動態卷積及動態 ResNet18 模型  
#    - 延緩溫度退火：初始溫度設為 31，每個 epoch 降 1（最低保持 1）  
#    - 在全連接層前加入 Dropout (p=0.1)  
# 3. 訓練過程中存下每個 epoch 的訓練與驗證結果  
# 4. 訓練後對測試集進行各種通道組合（RGB, RG, RB, GB, R, G, B）下的推論，結果存檔  
# 5. 訓練結束後顯示論文表格：包含訓練/驗證曲線、通道推論結果，以及模型參數與 FLOPS  

# %% [code]
import os
import csv
import random
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from PIL import Image
from torchvision import transforms

# %% [markdown]
# ## 1. 自訂 Dataset 讀取 txt 文件  
# 每行格式："images/xxx.JPEG label"

# %% [code]
class TxtImageDataset(Dataset):
    def __init__(self, txt_file, root_dir, transform=None):
        self.samples = []
        with open(txt_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line: continue
                path, label = line.split()
                self.samples.append((path, int(label)))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        full_path = os.path.join(self.root_dir, img_path)
        image = Image.open(full_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# %% [markdown]
# ## 2. 定義動態卷積模組及改進後的動態 ResNet18 模型

# %% [code]
class attention2d(nn.Module):
    def __init__(self, in_planes, ratio, K, temperature, init_weight=True):
        super(attention2d, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        hidden_planes = int(in_planes * ratio) + 1 if in_planes != 3 else K
        self.fc1 = nn.Conv2d(in_planes, hidden_planes, kernel_size=1, bias=False)
        self.fc2 = nn.Conv2d(hidden_planes, K, kernel_size=1, bias=True)
        self.temperature = temperature
        if init_weight:
            self._initialize_weights()
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    def update_temperature(self):
        if self.temperature > 1:
            self.temperature -= 1
            if self.temperature < 1:
                self.temperature = 1
            print(f"Attention temperature -> {self.temperature}")
    def forward(self, x):
        x = self.avgpool(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x).view(x.size(0), -1)
        return F.softmax(x / self.temperature, dim=1)

class Dynamic_conv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, K=4, temperature=31, init_weight=True):
        super(Dynamic_conv2d, self).__init__()
        assert in_planes % groups == 0
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.K = K
        self.attention = attention2d(in_planes, ratio, K, temperature)
        self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes // groups, kernel_size, kernel_size))
        self.bias = nn.Parameter(torch.zeros(K, out_planes)) if bias else None
        if init_weight:
            self._initialize_weights()
    def _initialize_weights(self):
        for i in range(self.K):
            nn.init.kaiming_uniform_(self.weight[i], mode='fan_in', nonlinearity='relu')
    def update_temperature(self):
        self.attention.update_temperature()
    def forward(self, x):
        B, C, H, W = x.size()
        att = self.attention(x)  # [B, K]
        x_flat = x.view(1, B*C, H, W)
        w_flat = self.weight.view(self.K, -1)
        agg_w = torch.mm(att, w_flat).view(B * self.out_planes, C // self.groups, self.kernel_size, self.kernel_size)
        if self.bias is not None:
            agg_b = torch.mm(att, self.bias).view(-1)
            out = F.conv2d(x_flat, agg_w, bias=agg_b,
                           stride=self.stride, padding=self.padding,
                           dilation=self.dilation, groups=self.groups * B)
        else:
            out = F.conv2d(x_flat, agg_w, bias=None,
                           stride=self.stride, padding=self.padding,
                           dilation=self.dilation, groups=self.groups * B)
        return out.view(B, self.out_planes, out.size(-2), out.size(-1))

def conv3x3(in_planes, out_planes, stride=1):
    return Dynamic_conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                          padding=1, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
    return Dynamic_conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1, downsample=None, norm_layer=None):
        super(BasicBlock, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return self.relu(out)

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(3, 2, 1)
        self.layer1 = self._make_layer(block, 64,  layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.dropout = nn.Dropout(p=0.1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    def _make_layer(self, block, planes, blocks, stride=1):
        norm_layer = nn.BatchNorm2d
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample, norm_layer)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, norm_layer=norm_layer))
        return nn.Sequential(*layers)
    def update_temperature(self):
        for child in self.children():
            if hasattr(child, "update_temperature"):
                child.update_temperature()
    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x); x = self.layer2(x)
        x = self.layer3(x); x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        return self.fc(x)

def dy_resnet18(num_classes):
    return ResNet(BasicBlock, [2,2,2,2], num_classes)

# %% [markdown]
# ## 3. 數據預處理與 DataLoader 設定

# %% [code]
data_root = "."
train_txt = os.path.join(data_root, "train.txt")
val_txt   = os.path.join(data_root, "val.txt")
test_txt  = os.path.join(data_root, "test.txt")
img_root  = os.path.join(data_root)

train_transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
eval_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

train_dataset = TxtImageDataset(train_txt, root_dir=img_root, transform=train_transform)
val_dataset   = TxtImageDataset(val_txt,   root_dir=img_root, transform=eval_transform)
test_dataset  = TxtImageDataset(test_txt,  root_dir=img_root, transform=eval_transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True,  num_workers=4)
val_loader   = DataLoader(val_dataset,   batch_size=64, shuffle=False, num_workers=4)
test_loader  = DataLoader(test_dataset,  batch_size=64, shuffle=False, num_workers=4)

# %% [markdown]
# ## 4. 訓練與驗證函數 & 訓練迴圈

# %% [code]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = dy_resnet18(num_classes=len(train_dataset.samples)).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
num_epochs = 90

with open("training_results.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["epoch", "train_loss", "train_acc", "val_loss", "val_acc"])

def train_epoch():
    model.train()
    total_loss, correct = 0, 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        for i in range(imgs.size(0)):
            if random.random() < 0.5:
                drop_n = 1 if random.random()<0.67 else 2
                for ch in random.sample([0,1,2], drop_n):
                    imgs[i,ch] = 0
        optimizer.zero_grad()
        out = model(imgs)
        loss = F.cross_entropy(out, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
        correct += (out.argmax(1) == labels).sum().item()
    return total_loss/len(train_loader.dataset), correct/len(train_loader.dataset)

def evaluate(loader):
    model.eval()
    total_loss, correct = 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            out = model(imgs)
            total_loss += F.cross_entropy(out, labels, reduction="sum").item()
            correct += (out.argmax(1) == labels).sum().item()
    return total_loss/len(loader.dataset), correct/len(loader.dataset)

best_val_acc = 0.0
for epoch in range(1, num_epochs+1):
    tr_loss, tr_acc = train_epoch()
    val_loss, val_acc = evaluate(val_loader)
    print(f"Epoch {epoch}: Train {tr_loss:.4f}/{tr_acc:.4f} | Val {val_loss:.4f}/{val_acc:.4f}")
    with open("training_results.csv", "a", newline="") as f:
        csv.writer(f).writerow([epoch, tr_loss, tr_acc, val_loss, val_acc])
    model.update_temperature()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_dy_resnet18.pth")
print(f"Best validation acc: {best_val_acc:.4f}")

# %% [markdown]
# ## 5. 測試：多通道組合推論

# %% [code]
def modify_channels(x, combo):
    mapping = {"R":0, "G":1, "B":2}
    out = torch.zeros_like(x)
    for c in combo:
        out[:, mapping[c]] = x[:, mapping[c]]
    return out

combos = ["RGB","RG","RB","GB","R","G","B"]
results = {c:{"correct":0,"total":0} for c in combos}

model.eval()
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        for combo in combos:
            mod = modify_channels(imgs, combo)
            out = model(mod)
            results[combo]["correct"] += (out.argmax(1)==labels).sum().item()
            results[combo]["total"]   += labels.size(0)

with open("channel_accuracy_dy.csv","w",newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["channel_combo","accuracy"])
    for c in combos:
        acc = results[c]["correct"]/results[c]["total"]
        writer.writerow([c, acc])
        print(f"{c}: {acc:.4f}")

# %% [markdown]
# ## 6. 顯示論文表格結果：Training/Val、Channel Accuracy、Model Complexity

# %% [code]
print("=== Training & Validation ===")
df_tr = pd.read_csv("training_results.csv")
print(df_tr.to_string(index=False))

print("\n=== Channel Accuracy ===")
df_ch = pd.read_csv("channel_accuracy_dy.csv")
print(df_ch.to_string(index=False))

print("\n=== Model Complexity (ptflops) ===")
try:
    from ptflops import get_model_complexity_info
    macs, params = get_model_complexity_info(model, (3,224,224),
                                             as_strings=True, print_per_layer_stat=False)
    print(f"Parameters (ptflops): {params}")
    print(f"FLOPS      (ptflops): {macs}")
except ImportError:
    print("請先安裝 ptflops：pip install ptflops")

# —— 額外使用 thop 計算 FLOPs 與 Params —— 
from thop import profile

dummy_input = torch.randn(1, 3, 224, 224).to(device)
flops, params = profile(model, inputs=(dummy_input,), verbose=False)
print(f"\n=== Model Complexity (thop) ===")
print(f"THOP Params: {params:,}  FLOPs: {flops:,}")


Epoch 1: Train 3.7822/0.0756 | Val 3.3454/0.1422
Epoch 2: Train 3.1482/0.1560 | Val 3.0401/0.1778
Epoch 3: Train 2.9061/0.2096 | Val 3.3838/0.2356
Epoch 4: Train 2.6845/0.2559 | Val 2.6246/0.2711
Epoch 5: Train 2.3745/0.3257 | Val 2.4049/0.3267
Epoch 6: Train 2.1149/0.3923 | Val 2.2974/0.3489
Epoch 7: Train 1.9028/0.4406 | Val 1.7976/0.4711
Epoch 8: Train 1.7324/0.4861 | Val 1.6680/0.4933
Epoch 9: Train 1.6036/0.5210 | Val 1.6275/0.5356
Epoch 10: Train 1.4929/0.5487 | Val 1.4649/0.5311
Epoch 11: Train 1.4208/0.5705 | Val 1.5268/0.5533
Epoch 12: Train 1.3423/0.5905 | Val 1.5081/0.5578
Epoch 13: Train 1.2804/0.6082 | Val 1.2821/0.6200
Epoch 14: Train 1.2315/0.6225 | Val 1.6187/0.5578
Epoch 15: Train 1.1938/0.6315 | Val 1.2918/0.6267
Epoch 16: Train 1.1454/0.6446 | Val 1.3201/0.5822
Epoch 17: Train 1.1127/0.6535 | Val 1.3411/0.5911
Epoch 18: Train 1.0890/0.6608 | Val 1.3062/0.5978
Epoch 19: Train 1.0589/0.6675 | Val 1.2375/0.6089
Epoch 20: Train 1.0343/0.6767 | Val 1.3187/0.5956
Epoch 21: