<a href="https://colab.research.google.com/github/eR3R3/EEG_Classification/blob/main/UNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
import numpy as np
import torch.optim as optim
import os
import yaml
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from typing import Dict, List, Optional, Union, Tuple, Iterable

In [None]:
# Everything in this notebook runs on the CPU; the tensors and the model
# created in later cells are moved to this device with `.to(device)`.
device = torch.device('cpu')
print(device)

import torch.nn.functional as F

class UNet(nn.Module):
    """U-Net-style encoder/decoder followed by a global-pool classifier head.

    The encoder has three convolutional stages separated by 2x2 max pooling,
    so the stages run at full, half and quarter resolution. The decoder
    upsamples with stride-2 transposed convolutions and adds the matching
    encoder feature map (additive skip connection). The full-resolution
    decoder output is globally average-pooled and fed to a linear layer.

    Args:
        in_channels: Number of channels of the input images.
        base_channels: Width of the first encoder stage; the following
            stages use 2x and 4x this width.
        num_classes: Number of logits produced per image.

    Input:  (batch_size, in_channels, H, W)
    Output: (batch_size, num_classes) raw logits.
    """

    def __init__(self, in_channels=3, base_channels=320, num_classes=10):
        super().__init__()

        c1 = base_channels
        c2 = base_channels * 2
        c3 = base_channels * 4

        # Encoder (downsampling path)
        self.enc1 = self.conv_block(in_channels, c1)
        self.enc2 = self.conv_block(c1, c2)
        self.enc3 = self.conv_block(c2, c3)
        # Halves the spatial size between encoder stages. ceil_mode keeps the
        # border row/column of odd-sized maps instead of dropping it.
        # The layer has no parameters, so state-dict keys are unaffected.
        self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

        # Decoder (upsampling path)
        self.upconv3 = self.upconv_block(c3, c2)
        self.upconv2 = self.upconv_block(c2, c1)

        # Global pooling + final fully connected layer
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)  # -> (batch_size, c1, 1, 1)
        self.fc = nn.Linear(c1, num_classes)  # -> (batch_size, num_classes)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 conv -> ReLU -> BatchNorm layers; spatial size is preserved."""
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
        )
        return block

    def upconv_block(self, in_channels, out_channels):
        """Return a stride-2 transposed conv -> ReLU -> BatchNorm block (doubles H and W)."""
        block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
        )
        return block

    def forward(self, x):
        """Compute class logits for a batch of images of shape (batch_size, in_channels, H, W)."""
        # Encoder: each stage after the first runs on a pooled (halved) map.
        enc1_out = self.enc1(x)  # (batch_size, c1, H, W)
        enc2_out = self.enc2(self.pool(enc1_out))  # (batch_size, c2, ~H/2, ~W/2)
        enc3_out = self.enc3(self.pool(enc2_out))  # (batch_size, c3, ~H/4, ~W/4)

        # Decoder
        upconv3_out = self.upconv3(enc3_out)  # (batch_size, c2, ~H/2, ~W/2)
        # Only changes anything for odd-sized maps, where doubling overshoots by one.
        upconv3_out = F.interpolate(upconv3_out, size=enc2_out.shape[2:])
        upconv3_out = upconv3_out + enc2_out  # skip connection

        upconv2_out = self.upconv2(upconv3_out)  # (batch_size, c1, ~H, ~W)
        upconv2_out = F.interpolate(upconv2_out, size=enc1_out.shape[2:])
        upconv2_out = upconv2_out + enc1_out  # skip connection

        # Global pooling and classification
        out = self.global_avg_pool(upconv2_out)  # (batch_size, c1, 1, 1)
        out = out.view(out.size(0), -1)  # (batch_size, c1)
        out = self.fc(out)  # (batch_size, num_classes)
        return out



cpu


In [None]:
def unpickle(file):
    """Load and return the object pickled in the file at path ``file``.

    The file is read in binary mode with ``encoding='bytes'``, so 8-bit
    strings written by Python 2 come back as ``bytes`` (dictionary keys
    then look like ``b'labels'``).

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only use this on files from a trusted source.
    """
    import pickle

    with open(file, 'rb') as fo:
        # Return directly instead of binding a local named `dict`, which
        # would shadow the builtin.
        return pickle.load(fo, encoding='bytes')

# Read the dataset's metadata file and turn the class names, which are stored
# as bytes, into regular strings.
meta_file_path = "/content/drive/MyDrive/dataset/cifer_10/batches.meta"
meta_file = unpickle(meta_file_path)
label_name = [raw_name.decode() for raw_name in meta_file[b'label_names']]
print(label_name[0])


airplane


In [None]:
def unpickle(file):
    """Load and return the object pickled in the file at path ``file``.

    The file is read in binary mode with ``encoding='bytes'``, so 8-bit
    strings written by Python 2 come back as ``bytes`` (dictionary keys
    then look like ``b'labels'``).

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only use this on files from a trusted source.
    """
    import pickle

    with open(file, 'rb') as fo:
        # Return directly instead of binding a local named `dict`, which
        # would shadow the builtin.
        return pickle.load(fo, encoding='bytes')

# Load the first training batch: keep its file names, and turn the pixel
# array and the label list into tensors living on `device`.
file = "/content/drive/MyDrive/dataset/cifer_10/data_batch_1"
img_data = unpickle(file)
img_data_keys = list(img_data)[:4]
img_filename = img_data[b'filenames']
img_label = torch.tensor(img_data[b'labels']).to(device)
# Rebind last: from here on `img_data` is the pixel tensor, not the raw dict.
img_data = torch.from_numpy(img_data[b'data']).to(device)
print(img_data_keys)
print(img_data.shape)
print(img_label.shape)

[b'batch_label', b'labels', b'data', b'filenames']
torch.Size([10000, 3072])
torch.Size([10000])


In [None]:
# Build the network and move its parameters to `device`.
vision_model = UNet()
# Kept as a bare expression: as the last statement of the cell its return
# value (the module itself) is what the notebook displays as the layer summary.
vision_model.to(device)

UNet(
  (enc1): Sequential(
    (0): Conv2d(3, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (enc2): Sequential(
    (0): Conv2d(320, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (enc3): Sequential(
    (0): Conv2d(640, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(1280, eps=1e-05

In [None]:
def preprocess(data):
    """Reshape image data to (N, 3, 32, 32) and normalize it per channel.

    Pixel values are converted to float, scaled by 1/255, and then
    standardized with a fixed per-channel mean and standard deviation
    (presumably the usual ImageNet statistics -- confirm if that matters).

    Args:
        data: Tensor whose elements can be viewed as (N, 3, 32, 32),
            e.g. flat rows of 3072 values per image.

    Returns:
        Float tensor of shape (N, 3, 32, 32) on the same device as ``data``.
    """
    img_data = data.reshape(-1, 3, 32, 32)
    img_data = img_data.float() / 255.0
    # Build the statistics on the input's device so the arithmetic below
    # also works when the data does not live on the default device.
    mean = torch.tensor([0.485, 0.456, 0.406], device=img_data.device)
    std = torch.tensor([0.229, 0.224, 0.225], device=img_data.device)
    # Shape (1, 3, 1, 1) broadcasts over the batch and spatial dimensions.
    mean = mean[None, :, None, None]
    std = std[None, :, None, None]
    normalized_images = (img_data - mean) / std
    return normalized_images
# Only inspect the preprocessed shape here. `img_data` must stay raw because
# VisionDataset applies `preprocess` to whatever it receives; rebinding
# `img_data` to the normalized tensor would scale and standardize it twice.
print(preprocess(img_data).shape)

class VisionDataset(Dataset):
    """In-memory dataset that pairs preprocessed images with their labels."""

    def __init__(self, data, label, preprocess):
        """Store the labels and the result of running ``preprocess`` on ``data``.

        Args:
            data: Raw image data; handed to ``preprocess`` once, up front.
            label: Indexable collection of labels aligned with the images.
            preprocess: Callable mapping ``data`` to the indexable collection
                of samples served by this dataset.
        """
        super().__init__()
        self.preprocess = preprocess
        self.label = label
        # Preprocessing happens eagerly, so indexing later is a plain lookup.
        self.data = self.preprocess(data)

    def __len__(self):
        """Return the number of preprocessed samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        return self.data[idx], self.label[idx]


# The first 8000 samples are used for training; samples 8000..9983 are held
# out for validation.
train_set = VisionDataset(img_data[:8000], img_label[:8000], preprocess)
val_set = VisionDataset(img_data[8000:9984], img_label[8000:9984], preprocess)

batch_size = 32
# Only the training batches are drawn in random order.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)


# Sanity check: run the untrained model on a single training batch and print
# the predicted class names next to the ground-truth class names.
for num, (img, label) in enumerate(train_loader):
    # Pure inspection, so no autograd graph is needed.
    with torch.no_grad():
        logits = vision_model(img)
        probs = nn.functional.softmax(logits, dim=-1)
    print(probs.shape)
    # Index of the most probable class for every sample in the batch.
    max_list = probs.argmax(dim=-1).tolist()
    # Build the name lists once per batch rather than once per sample.
    pred = [label_name[i] for i in max_list]
    value = [label_name[i] for i in label.tolist()]
    print(f"this is the {num} batch")
    print(pred)
    print(value)
    break  # one batch is enough

torch.Size([10000, 3, 32, 32])
torch.Size([32, 10])
this is the 0 batch
['automobile', 'automobile', 'ship', 'dog', 'cat', 'frog', 'dog', 'horse', 'cat', 'cat', 'cat', 'cat', 'cat', 'cat', 'dog', 'cat', 'dog', 'automobile', 'automobile', 'automobile', 'dog', 'ship', 'ship', 'cat', 'cat', 'deer', 'cat', 'ship', 'cat', 'automobile', 'ship', 'dog']
['bird', 'bird', 'ship', 'deer', 'cat', 'automobile', 'bird', 'frog', 'truck', 'horse', 'automobile', 'deer', 'dog', 'truck', 'frog', 'frog', 'automobile', 'bird', 'dog', 'dog', 'frog', 'ship', 'airplane', 'frog', 'cat', 'deer', 'cat', 'airplane', 'frog', 'bird', 'truck', 'automobile']


In [None]:
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

# Train with Adam on the cross-entropy loss and stop early once the
# validation loss has not improved for `patience` consecutive epochs.
optimizer = Adam(vision_model.parameters(), lr=1e-4)
criterion = CrossEntropyLoss()

best_val_loss = float('inf')
best_state = None  # copy of the weights that achieved best_val_loss
patience = 10
counter = 0

for epoch in range(100):  # train for at most 100 epochs
    # Training phase
    vision_model.train()
    for img, label in train_loader:
        optimizer.zero_grad()
        output = vision_model(img)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

    # Validation phase
    vision_model.eval()
    val_loss = 0
    with torch.no_grad():
        for img, label in val_loader:
            output = vision_model(img)
            val_loss += criterion(output, label).item()
    val_loss /= len(val_loader)  # mean of the per-batch losses

    print(f"Epoch {epoch}, Validation Loss: {val_loss:.4f}")

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0  # reset the patience counter
        # state_dict() returns references to the live tensors, so clone them;
        # otherwise later optimizer steps would overwrite the snapshot.
        best_state = {
            name: tensor.detach().clone()
            for name, tensor in vision_model.state_dict().items()
        }
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered.")
            break

# Without this the model would keep the weights of the last epoch, which are
# up to `patience` epochs past the best validation loss.
if best_state is not None:
    vision_model.load_state_dict(best_state)