In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class GaussianFourierProjection(nn.Module):
  """Map a batch of scalar time steps to fixed random sine/cosine features.

  `embed_dim // 2` frequencies are drawn from a normal distribution scaled by
  `scale` when the module is built. They are stored as a parameter with
  `requires_grad=False`, so they appear in the state dict but receive no
  gradient.
  """
  def __init__(self, embed_dim, scale=30.):
    super().__init__()
    frequencies = torch.randn(embed_dim // 2) * scale
    self.W = nn.Parameter(frequencies, requires_grad=False)

  def forward(self, x):
    # Outer product of the time steps and the frequencies, as angles in radians.
    angles = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * np.pi
    # Sine features first, cosine features second, along the last axis.
    features = (torch.sin(angles), torch.cos(angles))
    return torch.cat(features, dim=-1)


class Dense(nn.Module):
  """Linear layer whose output gets two trailing singleton axes appended.

  The extra axes let the result broadcast against a feature map that has two
  more dimensions than the linear output.
  """
  def __init__(self, input_dim, output_dim):
    super().__init__()
    self.dense = nn.Linear(input_dim, output_dim)

  def forward(self, x):
    projected = self.dense(x)
    # (..., output_dim) -> (..., output_dim, 1, 1)
    return projected.unsqueeze(-1).unsqueeze(-1)

class ScoreNet(nn.Module):
  """Time-conditioned U-Net whose output has the same shape as its input.

  Four convolutions (the last three with stride 2) encode the image. Three
  stride-2 transposed convolutions followed by one stride-1 transposed
  convolution decode it, with encoder activations concatenated onto the decoder
  path as skip connections. A linear projection of the time embedding is added
  after every convolution except the last one, and the final output is divided
  by `marginal_prob_std(t)`.

  Args:
    marginal_prob_std: Callable that maps the batch of times `t` to a 1-D
      tensor; it is indexed as `[:, None, None, None]` to rescale the output.
    channels: Four feature-map widths, one per resolution level. Because the
      GroupNorm layers use fixed group counts (4 for the first encoder norm,
      32 everywhere else), every entry must be divisible by 32.
    embed_dim: Width of the time embedding. It should be even, since the
      Fourier projection emits `2 * (embed_dim // 2)` features that feed a
      `Linear(embed_dim, embed_dim)`.
    in_channels: Channel count of the input image and of the returned tensor.

  Note:
    The input height and width must be divisible by 8. Each stride-2
    transposed convolution exactly doubles the spatial size, so the decoder
    maps only line up with the encoder maps they are concatenated with when
    the three stride-2 encoder convolutions halve the size without rounding.
  """
  def __init__(self, marginal_prob_std, channels=(32, 64, 128, 256), embed_dim=256,
               in_channels=3):
    super().__init__()
    # Time embedding: fixed random Fourier features followed by a learned linear map.
    self.embed = nn.Sequential(
        GaussianFourierProjection(embed_dim=embed_dim),
        nn.Linear(embed_dim, embed_dim)
    )

    # Encoder. Every conv is paired with a Dense layer that projects the time
    # embedding to that conv's channel count, and a GroupNorm.
    self.conv1 = nn.Conv2d(in_channels, channels[0], 3, stride=1, padding=1, bias=False)
    self.dense1 = Dense(embed_dim, channels[0])
    self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])

    self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, padding=1, bias=False)
    self.dense2 = Dense(embed_dim, channels[1])
    self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])

    self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, padding=1, bias=False)
    self.dense3 = Dense(embed_dim, channels[2])
    self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])

    self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, padding=1, bias=False)
    self.dense4 = Dense(embed_dim, channels[3])
    self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

    # Decoder. From tconv3 on, the input width is doubled because the matching
    # encoder activation is concatenated along the channel axis in forward().
    self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 4, stride=2, padding=1, bias=False)
    self.dense5 = Dense(embed_dim, channels[2])
    self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])

    self.tconv3 = nn.ConvTranspose2d(channels[2]+channels[2], channels[1], 4, stride=2, padding=1, bias=False)
    self.dense6 = Dense(embed_dim, channels[1])
    self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])

    self.tconv2 = nn.ConvTranspose2d(channels[1]+channels[1], channels[0], 4, stride=2, padding=1, bias=False)
    self.dense7 = Dense(embed_dim, channels[0])
    self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])

    # Final layer maps back to the image channel count (no norm, no activation).
    self.tconv1 = nn.ConvTranspose2d(channels[0]+channels[0], in_channels, 3, stride=1, padding=1)

    # Swish / SiLU activation, x * sigmoid(x). A module is used instead of a
    # lambda attribute so the whole network can be pickled; it holds no
    # parameters or buffers, so the state dict keys are unaffected.
    self.act = nn.SiLU()
    self.marginal_prob_std = marginal_prob_std

  def forward(self, x, t):
    """Return the network output for images `x` at times `t`.

    Args:
      x: Image batch, indexed as (batch, channel, height, width).
      t: One time value per batch element (1-D tensor).

    Returns:
      Tensor with the same shape as `x`, already divided by
      `marginal_prob_std(t)`.
    """
    # Shared time embedding; each Dense layer below adapts it to one level.
    embed = self.act(self.embed(t))

    # Encoding path
    h1 = self.act(self.gnorm1(self.conv1(x) + self.dense1(embed)))
    h2 = self.act(self.gnorm2(self.conv2(h1) + self.dense2(embed)))
    h3 = self.act(self.gnorm3(self.conv3(h2) + self.dense3(embed)))
    h4 = self.act(self.gnorm4(self.conv4(h3) + self.dense4(embed)))

    # Decoding path with skip connections from the encoder
    h = self.act(self.tgnorm4(self.tconv4(h4) + self.dense5(embed)))
    h = self.act(self.tgnorm3(self.tconv3(torch.cat([h, h3], dim=1)) + self.dense6(embed)))
    h = self.act(self.tgnorm2(self.tconv2(torch.cat([h, h2], dim=1)) + self.dense7(embed)))
    h = self.tconv1(torch.cat([h, h1], dim=1))

    # Rescale per sample by the noise standard deviation at time t.
    h = h / self.marginal_prob_std(t)[:, None, None, None]
    return h

In [None]:
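# -------------------------------
# Training cell: dataset, hyperparameters, loss function and training loop.
# Requires ScoreNet from the previous cell; marginal_prob_std_fn and loss_fn
# are defined further down in this cell.
# -------------------------------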
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
import h5py
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import tqdm
import os

class PCamDataset_k(Dataset):
    """Image/label pairs read from two HDF5 files (datasets 'x' and 'y').

    The HDF5 files are opened lazily, once per process, instead of in
    ``__init__``. An open h5py handle must not be shared with DataLoader worker
    processes: with fork the workers would read through the same underlying
    file descriptor, and with spawn the dataset could not be pickled at all.
    The constructor still opens both files briefly so that a bad path fails
    immediately.

    Args:
        x_path: Path of the HDF5 file holding the image dataset ``'x'``.
        y_path: Path of the HDF5 file holding the label dataset ``'y'``.
        transform: Optional callable applied to each image after it has been
            expanded to three channels.
    """

    def __init__(self, x_path, y_path, transform=None):
        super().__init__()
        self.x_path = x_path
        self.y_path = y_path
        self.transform = transform
        self._x_h5 = None
        self._y_h5 = None
        self._pid = None  # process that owns the currently open handles
        # Read the length once and release the handle again right away.
        with h5py.File(x_path, 'r') as x_file:
            self._length = x_file['x'].shape[0]
        # Fail fast on an unreadable label file, as opening in __init__ did.
        with h5py.File(y_path, 'r'):
            pass

    def _ensure_open(self):
        """Open both files in the current process if that has not happened yet."""
        pid = os.getpid()
        if self._pid != pid:
            # Handles inherited from another process (fork) must not be reused.
            self._x_h5 = None
            self._y_h5 = None
            self._pid = pid
        if self._x_h5 is None:
            self._x_h5 = h5py.File(self.x_path, 'r')
        if self._y_h5 is None:
            self._y_h5 = h5py.File(self.y_path, 'r')

    @property
    def x_h5(self):
        """Open handle of the image file, valid in the current process."""
        self._ensure_open()
        return self._x_h5

    @property
    def y_h5(self):
        """Open handle of the label file, valid in the current process."""
        self._ensure_open()
        return self._y_h5

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; the image always has 3 channels."""
        img = self.x_h5['x'][idx]
        if img.ndim == 2:
            # (H, W) grayscale -> (H, W, 3)
            img = np.stack([img]*3, axis=-1)
        elif img.shape[2] == 1:
            # (H, W, 1) grayscale -> (H, W, 3)
            img = np.concatenate([img]*3, axis=2)
        if self.transform:
            img = self.transform(img)
        label = self.y_h5['y'][idx]
        return img, label

    def close(self):
        """Close any handles opened so far; they are reopened on next access."""
        for name in ('_x_h5', '_y_h5'):
            handle = getattr(self, name, None)
            if handle is not None:
                handle.close()
            setattr(self, name, None)

    def __getstate__(self):
        # h5py handles cannot be pickled; the receiving process reopens them.
        state = self.__dict__.copy()
        state['_x_h5'] = None
        state['_y_h5'] = None
        state['_pid'] = None
        return state


# ----------- Hyperparameters -------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_epochs = 500
batch_size = 16
lr = 2e-4
sigma = 10.0


def marginal_prob_std(t, sigma):
    """Return sqrt((sigma^(2t) - 1) / (2 ln sigma)) for a tensor of times `t`.

    loss_fn uses the result as the standard deviation of the noise it adds to
    the data, and ScoreNet divides its output by it.
    """
    return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))


marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)

# ----------- Model and data -------------
# Directory holding the PCam HDF5 files. Set the PCAM_DATA_DIR environment
# variable to override it; the default is the location used so far.
data_dir = os.environ.get('PCAM_DATA_DIR', '/home/gotou/Medical')

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),  # small rotation
    transforms.ColorJitter(brightness=0.1, contrast=0.1),  # slight color change
    transforms.ToTensor()
])

dataset = PCamDataset_k(
    os.path.join(data_dir, 'camelyonpatch_level_2_split_train_x.h5'),
    os.path.join(data_dir, 'camelyonpatch_level_2_split_train_y.h5'),
    transform=transform
)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)

score_model = ScoreNet(marginal_prob_std=marginal_prob_std_fn).to(device)
optimizer = Adam(score_model.parameters(), lr=lr, betas=(0.9, 0.99))
# Start from +inf so the first epoch always yields a "best" checkpoint. The loss
# sums squared errors over every pixel and channel, so a fixed cap such as 1000
# could stay below the loss forever and no best model would ever be saved.
best_loss = float('inf')


# ----------- Loss function -------------
def loss_fn(model, x, marginal_prob_std, eps=1e-5):
    """Denoising score-matching loss for one batch.

    Args:
        model: Callable ``model(perturbed_x, t)`` returning a tensor shaped like ``x``.
        x: Batch of data with four dimensions, batch axis first.
        marginal_prob_std: Callable mapping the sampled times to one noise
            standard deviation per batch element.
        eps: Lower bound of the sampled times, which are uniform in [eps, 1).

    Returns:
        Scalar tensor: the batch mean of the per-sample sum of
        ``(score * std + z) ** 2``.
    """
    batch = x.shape[0]
    # One random time per sample, kept away from t = 0.
    random_t = torch.rand(batch, device=x.device) * (1. - eps) + eps
    noise = torch.randn_like(x)
    # Per-sample std, shaped to broadcast over channel and spatial axes.
    std_b = marginal_prob_std(random_t)[:, None, None, None]
    noisy_x = x + noise * std_b
    predicted = model(noisy_x, random_t)
    residual = predicted * std_b + noise
    per_sample = residual.pow(2).sum(dim=(1, 2, 3))
    return per_sample.mean()

# ----------- Training loop -------------
os.makedirs("checkpoints", exist_ok=True)
tqdm_epoch = tqdm.trange(n_epochs, mininterval=1.0)
for epoch in tqdm_epoch:
    score_model.train()
    # Running sum of per-sample losses and the number of samples seen this epoch.
    avg_loss, num_items = 0., 0
    for i, (x, _) in enumerate(data_loader):
        x = x.to(device)
        batch_n = x.shape[0]

        loss = loss_fn(score_model, x, marginal_prob_std_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a short last batch counts correctly.
        avg_loss += loss.item() * batch_n
        num_items += batch_n
        # Cap the epoch length; raise or drop this limit to see more data per epoch.
        if i >= 256:
            break

    tqdm_epoch.set_description(f"Epoch {epoch+1} | Loss: {avg_loss / num_items:.4f}")

    # Keep the weights of the epoch with the lowest mean training loss so far.
    epoch_loss = avg_loss / num_items
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(score_model.state_dict(), 'checkpoints/best_model.pth')


print(" 学習完了")
  



In [None]:
# Once the whole training loop has finished, persist the last-epoch weights
# (independent of the best-loss checkpoint written during training).
final_state = score_model.state_dict()
torch.save(final_state, 'checkpoints/final_model.pth')
print("最終モデルを checkpoints/final_model.pth に保存しました。")
