In [None]:
import torch
# Quick sanity check: the cell's displayed value is True when PyTorch can see a CUDA device.
torch.cuda.is_available()

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange,Reduce
from torchvision.io import read_image
import torchvision
import itertools
from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis
from torchvision.utils import make_grid, save_image
from torchvision.transforms import v2
import math
import numpy as np
import icecream
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp

import sys

# Put the parent directory and the current directory on the import path (in that
# order) so the project packages util/ and setting/ can be imported; presumably
# the notebook lives one level below the repository root.
sys.path.extend(["..", "."])

from util.debug_print import dprint
# NOTE(review): star import — presumably supplies DEVICE, B, C, C1..C4, H, W,
# DENOISE_PARAMS, DENOISE_TIME, NOISE_SCALE, SCALE_FACTOR, TRANSFORMER_RECURSION,
# TEST_IMG and RESULT_DIR, none of which are defined in this notebook.
from setting.setting import *
import util.my_transform as mt

# Module-level verbosity flag passed to every dprint(...) call in AffinDiffusion.forward.
debug = True


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A fixed sinusoidal encoding of width ``frequency_embedding_size`` is passed
    through a two-layer MLP (Linear -> SiLU -> Linear) that outputs
    ``hidden_size`` features per timestep.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, dim) float32 Tensor on ``t``'s device. The first half of the
                 columns holds cosines and the second half sines. When ``dim`` is odd,
                 one zero column is appended.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        # Frequencies are computed on CPU and moved to t's device rather than to a
        # global DEVICE, so CPU and GPU timesteps both work.
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """Map a 1-D tensor of N timesteps to an (N, hidden_size) embedding."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
    
class Conv2d(nn.Module):
    """Thin wrapper around ``nn.Conv2d`` that adds a filter-diversity regulariser.

    The wrapped layer is stored as ``self.conv``, so its state_dict keys are
    ``conv.weight`` and ``conv.bias``.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

    def forward(self, x):
        return self.conv(x)

    def diversity_loss(self):
        """Penalty that grows as the layer's output filters become linearly similar.

        Each output filter is flattened to one row of ``in_channels * kH * kW``
        weights and centred by its own mean. The method then forms the Gram matrix
        of these rows (an unnormalised covariance) and zeroes its diagonal. It
        returns, as a 0-d tensor, the mean absolute value over all
        ``out_channels ** 2`` entries; the zeroed diagonal is still counted.
        """
        # reshape (not view): view raises on non-contiguous weights, e.g. after the
        # module is converted to the channels_last memory format.
        weight = self.conv.weight.reshape(self.conv.out_channels, -1)
        centered = weight - weight.mean(dim=1, keepdim=True)
        gram = centered @ centered.t()  # (out_channels, out_channels)
        off_diagonal = gram - torch.diag_embed(torch.diag(gram))
        # Mean (not sum) of the absolute off-diagonal entries.
        return off_diagonal.abs().mean()

class DownBlock(nn.Module):
    """Downsampling stage: project Conv2d -> BatchNorm2d -> SiLU.

    With the defaults (kernel 3, stride 2, padding 1) each spatial side of
    size n becomes ceil(n / 2).
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        super().__init__()
        conv = Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.SiLU()
        # The attribute name and layer order define the state_dict keys
        # (Block.0.*, Block.1.*); diversity_loss() also reaches Block[0] directly.
        self.Block = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.Block(x)

class UpBlock(nn.Module):
    """Upsampling stage: ConvTranspose2d -> BatchNorm2d -> SiLU.

    With the defaults (kernel 3, stride 2, padding 1) a side of size n becomes
    2n - 1. With kernel 2, stride 2 and padding 0 it becomes exactly 2n.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.SiLU(),
        ]
        # Kept as `Block` (an nn.Sequential) so state_dict keys stay Block.0.*, Block.1.*.
        self.Block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.Block(x)
        return out
    
class AffinDiffusion(nn.Module):
    """
    Experimental denoising network.

    Pipeline in ``forward``:
    - BatchNorm on the input image, then four DownBlocks.
    - Flatten to an (N, C4*2*2) latent and add a timestep embedding.
    - Apply a TransformerEncoder TRANSFORMER_RECURSION times.
    - A sigmoid parameter head ``mlp`` produces DENOISE_PARAMS outputs.
    - The latent, and a second "noise" latent built from ``mlp[:, 1:]``, are
      both decoded by the same seven-stage UpBlock stack.

    Size constants (C1..C4, B, H, W, DENOISE_PARAMS, TRANSFORMER_RECURSION,
    SCALE_FACTOR) come from ``setting.setting``. The DownBlocks shrink each
    spatial side by 4*4*4*2 = 128 and ``rearrange2`` expects a 2x2 latent, so
    inputs are presumably 256x256 — TODO confirm against the settings.

    Attribute names and construction order define the state_dict, which
    ``load_weights`` restores with a strict ``load_state_dict``. That is why the
    unused ``transformerdecoder``, ``Embedding3`` and ``Embedding4`` are kept.
    """
    def __init__(self):
        super().__init__()
        self.batchnorm = nn.BatchNorm2d(C1)
        # Encoder: kernel == stride and no padding, so each stage divides the
        # spatial size by its stride (4, 4, 4, 2 -> 128 in total).
        self.DownBlock = nn.Sequential(
            DownBlock(C1, C2, kernel_size=4, stride=4, padding=0),
            DownBlock(C2, C3, kernel_size=4, stride=4, padding=0),
            DownBlock(C3, C4, kernel_size=4, stride=4, padding=0),
            DownBlock(C4, C4, kernel_size=2, stride=2, padding=0),
        )

        encoder_layer = nn.TransformerEncoderLayer(d_model=C4*2*2, nhead=8, batch_first=True)
        self.transformerencoder = nn.TransformerEncoder(encoder_layer, num_layers=3)
        self.rearrange = Rearrange("b c h w -> b (c h w)")
        # Parameter head; the sigmoid keeps every predicted parameter in (0, 1).
        self.mlp = nn.Sequential(
            nn.Linear(C4*2*2, DENOISE_PARAMS, bias=True),
            nn.Sigmoid(),
        )

        # Not used by forward(); kept so checkpoints saved from this class still load.
        decoder_layer = nn.TransformerDecoderLayer(d_model=C4*2*2, nhead=8, batch_first=True)
        self.transformerdecoder = nn.TransformerDecoder(decoder_layer, num_layers=3)

        self.rearrange2 = Rearrange("b (c h w) -> b c h w", c=C4, h=2, w=2)

        # Decoder: seven exact 2x upsamplings (x128 in total), mirroring the encoder.
        self.UpBlock = nn.Sequential(
            UpBlock(C4, C4, kernel_size=2, stride=2, padding=0),
            UpBlock(C4, C4, kernel_size=2, stride=2, padding=0),
            UpBlock(C4, C3, kernel_size=2, stride=2, padding=0),
            UpBlock(C3, C3, kernel_size=2, stride=2, padding=0),
            UpBlock(C3, C2, kernel_size=2, stride=2, padding=0),
            UpBlock(C2, C2, kernel_size=2, stride=2, padding=0),
            UpBlock(C2, C1, kernel_size=2, stride=2, padding=0),
        )

        # Lifts mlp[:, 1:] to a latent-sized vector. in_features=2 presumably
        # means DENOISE_PARAMS == 3 — TODO confirm.
        self.mlp2 = nn.Sequential(
            nn.Linear(2, C4*2*2),
            nn.SiLU(),
        )

        self.t = TimestepEmbedder(hidden_size=C4*2*2)
        # Not used by forward(); kept so checkpoints saved from this class still load.
        self.Embedding3 = nn.Parameter(torch.zeros(B, C4, 2, 2), requires_grad=True)
        self.Embedding4 = nn.Parameter(torch.zeros(B, C4, 2, 2), requires_grad=True)

    def weight_init(self, module):
        """Set the parameter head's Linear layer (weight and bias) to the constant 0.001.

        Not called anywhere in this file. ``module`` is accepted so the method can
        be handed to ``nn.Module.apply``, but it is ignored.
        """
        head = self.mlp[0]  # nn.Sequential has no .weight; the Linear is element 0
        with torch.no_grad():
            head.weight.fill_(0.001)
            head.bias.fill_(0.001)

    def forward(self, x, i=None):
        """
        Run one denoising step.

        :param x: image batch, presumably (N, C1, H, W).
        :param i: timestep tensor passed to the TimestepEmbedder, e.g. ``tensor([t])``.
                  Defaults to ``tensor([100])`` on ``x``'s device.
        :return: a 4-tuple ``(out_t, mlp, out, noise)``:
            - ``out_t`` is the decoded latent.
            - ``mlp`` holds the (N, DENOISE_PARAMS) sigmoid parameters.
            - ``noise`` is the decoded noise latent.
            - ``out`` is ``out_t`` rescaled per sample by ``SCALE_FACTOR ** mlp[:, 0]``
              via ``mt.Affine``, plus ``noise``.
        """
        if i is None:
            # Created per call instead of once at class-definition time, so
            # importing this cell no longer allocates a tensor on DEVICE.
            i = torch.tensor([100], device=x.device)
        x = self.batchnorm(x)
        dprint("img.shape", x.shape, debug)
        latent = self.DownBlock(x)
        dprint("latent.shape", latent.shape, debug)
        latent = self.rearrange(latent)

        time_embedding = self.t(i)
        latent = latent + time_embedding

        dprint("latent.shape", latent.shape, debug)
        # NOTE(review): latent is 2-D (N, d_model). With batch_first=True, PyTorch
        # treats 2-D input as one unbatched sequence of length N, so attention mixes
        # samples across the batch. This is identical for N == 1 (the training loop)
        # but probably unintended for N > 1; consider latent.unsqueeze(1).
        for _ in range(TRANSFORMER_RECURSION):
            latent = self.transformerencoder(latent)
        mlp = self.mlp(latent)
        dprint("mlp", mlp.shape, debug)

        noise = self.mlp2(mlp[:, 1:])
        dprint("noise.shape", noise.shape, debug)

        latent = self.rearrange2(latent)
        noise = self.rearrange2(noise)
        dprint("noise.shape", noise.shape, debug)
        dprint("latent.shape", latent.shape, debug)
        out_t = self.UpBlock(latent)
        noise = self.UpBlock(noise)  # same decoder weights as out_t
        dprint("out_t.shape", out_t.shape, debug)
        dprint("noise.shape", noise.shape, debug)

        out = out_t.clone()
        for b in range(x.shape[0]):
            # mlp[b, 0] is a sigmoid output, so the scale lies between 1 and
            # SCALE_FACTOR. .item() means no gradient reaches mlp[:, 0] through it.
            scale = math.pow(SCALE_FACTOR, mlp[b][0].item())
            out[b] = mt.Affine(size=(H, W), scale=scale)(out_t[b]) + noise[b]
        return out_t, mlp, out, noise

if __name__ == "__main__":
    from torchinfo import summary

    # Print a layer-by-layer summary (output sizes, parameter counts) for a dummy
    # batch of two C x H x W images.
    model = AffinDiffusion().to(DEVICE, non_blocking=True)
    # Stored as model_stats so the imported torchinfo.summary function is not
    # shadowed by its own result.
    model_stats = summary(
        model,
        input_size=(2, C, H, W),
        col_names=["output_size", "num_params"],
        row_settings=["var_names"],
    )
    print(model_stats)
    del model, model_stats




In [None]:
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
from torchmetrics.image import PeakSignalNoiseRatio
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as tF
import cv2
# Figures written with savefig are cropped to a tight bounding box.
plt.rcParams.update({"savefig.bbox": "tight"})


def show(out, img):
    """Plot the first image of ``out`` (prediction) next to the first image of ``img`` (reference).

    Both arguments are indexed as ``x[0, :, :, :]``, so they are presumably
    (N, C, H, W) batches. They are detached and moved to CPU before plotting.
    Floating-point grids are clamped to [0, 1] first: ``to_pil_image`` multiplies
    floats by 255 and casts to uint8, so out-of-range predictions would otherwise
    wrap around and show as noise.
    """
    pred = out[0, :, :, :].detach().cpu()
    orig = img[0, :, :, :].detach().cpu()
    grid = torchvision.utils.make_grid([pred, orig], nrow=2)
    if grid.is_floating_point():
        grid = grid.clamp(0.0, 1.0)
    fig, ax = plt.subplots()
    ax.imshow(np.asarray(tF.to_pil_image(grid)))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()

def show_cv2(out, img):
    """Open an OpenCV window "out", then one named "img", and block until a key is pressed.

    The arguments are passed to ``cv2.imshow`` unchanged, so they must already be
    something cv2 can display. Presumably these are NumPy image arrays, since
    torch tensors are not converted here.
    """
    for window, frame in (("out", out), ("img", img)):
        cv2.imshow(window, frame)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

def label_loss(label, mlp):
    """Euclidean (L2) distance between target parameters ``label`` and predictions ``mlp``.

    :param label: target values; anything ``torch.as_tensor`` accepts (list, ndarray, tensor).
    :param mlp: predicted parameters. Flattened, so (1, K) and (K,) both work; the
                previous hard-coded ``view(5)`` restricted this to 5 elements.
    :return: 0-d tensor ``||label - mlp||_2`` that keeps ``mlp``'s autograd graph.

    ``label`` is built on ``mlp``'s device with ``torch.as_tensor``. This avoids the
    copy-construct warning for tensor labels and a device mismatch for GPU predictions.
    """
    mlp = mlp.reshape(-1)
    label = torch.as_tensor(label, dtype=torch.float32, device=mlp.device)
    return torch.norm(label - mlp, dim=0)

def diversity_loss(model):
    """Sum ``diversity_loss()`` over the first layer (``Block[0]``) of every stage in ``model.DownBlock``.

    Returns the integer 0 when ``model.DownBlock`` is empty.
    """
    total = 0
    for stage in model.DownBlock:
        total = total + stage.Block[0].diversity_loss()
    return total

class Zero255Loss(nn.Module):
    """Fraction of pixels saturated at either end of the value range.

    A pixel counts as saturated when its value is ``<= low`` or ``>= high``. The
    defaults (0 and 255) match 8-bit images; pass ``high=1.0`` for images scaled
    to [0, 1], as elsewhere in this notebook.

    The count is taken on a detached tensor, so the result carries no gradient.
    It is a monitoring metric rather than a trainable loss.
    """
    def __init__(self, low=0.0, high=255.0):
        super().__init__()
        self.low = low
        self.high = high

    def forward(self, generated_image):
        """Return the saturated-pixel ratio as a 0-d float32 tensor on the input's device.

        An empty input yields 0 instead of 0/0.
        """
        pixels = generated_image.detach()
        total_pixels = pixels.numel()
        if total_pixels == 0:
            return torch.zeros((), dtype=torch.float32, device=pixels.device)
        saturated = (pixels <= self.low) | (pixels >= self.high)
        # Divide in float64 (like the former NumPy computation), then return float32.
        return (saturated.sum().double() / total_pixels).float()

def fid(img, out):
    """Fréchet Inception Distance (64-d Inception features) between real ``img`` and generated ``out``.

    ``FrechetInceptionDistance`` with its default ``normalize=False`` expects
    ``torch.uint8`` images in [0, 255]; the former cast to ``torch.int8`` is
    rejected and wraps values above 127. Floating-point inputs are assumed to
    span [0, 1], as produced by the ``/255`` scaling in this notebook; they are
    clamped, scaled by 255 and rounded. Integer inputs are clamped to [0, 255].
    Both batches presumably need 3 channels and at least two images each for
    the covariance estimate — TODO confirm for the torchmetrics version in use.
    """
    def _to_uint8(images):
        # Convert one batch to the uint8 [0, 255] layout the metric expects.
        images = images.detach()
        if images.is_floating_point():
            images = (images.clamp(0.0, 1.0) * 255).round()
        else:
            images = images.clamp(0, 255)
        return images.to(torch.uint8)

    # Keep the metric's Inception model and state on the inputs' device.
    metric = FrechetInceptionDistance(feature=64).to(img.device)
    metric.update(_to_uint8(img), real=True)
    metric.update(_to_uint8(out), real=False)
    return metric.compute()

def ms_ssim(img, out):
    """Multi-scale SSIM between two image batches whose values span [0, 1] (``data_range=1.0``).

    The metric is moved to ``img``'s device so its internal state tensors and the
    inputs live on the same device. A freshly built metric sits on the CPU and
    would otherwise fail to combine with CUDA inputs.
    """
    metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0).to(img.device)
    return metric(img, out)

def label_criterion(label, mlp):
    """Euclidean (L2) distance between target parameters ``label`` and predictions ``mlp``.

    :param label: target values; anything ``torch.as_tensor`` accepts (list, ndarray, tensor).
    :param mlp: predicted parameters. Flattened, so (1, K) and (K,) both work; it is
                no longer forced to exactly DENOISE_PARAMS elements via ``view``.
    :return: 0-d tensor ``||label - mlp||_2`` that keeps ``mlp``'s autograd graph.

    ``label`` is created on ``mlp``'s device with ``torch.as_tensor``. This avoids the
    copy-construct warning for tensor labels and a CPU/GPU device mismatch.
    """
    mlp = mlp.reshape(-1)
    label = torch.as_tensor(label, dtype=torch.float32, device=mlp.device)
    return torch.norm(label - mlp, dim=0)


def save_weights(model, path):
    """Write ``model.state_dict()`` (parameters and buffers, not the module object) to ``path`` via ``torch.save``."""
    state = model.state_dict()
    torch.save(state, path)

def load_weights(model, path, map_location="cpu"):
    """Load a state dict written by ``save_weights`` from ``path`` into ``model`` and return ``model``.

    :param map_location: where tensors are deserialised first (CPU by default).
        This lets a checkpoint saved from a GPU be opened on a machine without
        one. ``load_state_dict`` then copies every tensor into the model's existing
        parameters on whatever device they already live on.

    ``weights_only=True`` limits unpickling to tensors and plain containers, so a
    tampered checkpoint file cannot execute arbitrary code. The load is strict:
    missing or unexpected keys raise ``RuntimeError``.
    """
    state = torch.load(path, map_location=map_location, weights_only=True)
    model.load_state_dict(state)
    return model



In [None]:
from datasets import load_dataset

# Hugging Face "sasha/dog-food" dataset, formatted so each example's columns come
# back as torch tensors placed on DEVICE.
ds = load_dataset("sasha/dog-food").with_format("torch", device=DEVICE)

In [None]:
# --- Training ---------------------------------------------------------------
# Each training image is randomly cropped to H x W. For every noise level i, the
# model sees the crop at level i and is trained towards the crop at level i - 1.
debug = False  # silences the dprint(...) calls inside AffinDiffusion.forward

model = AffinDiffusion().to(DEVICE)
model = load_weights(model, "weights.pth")  # resume from the previous checkpoint

ergas = ErrorRelativeGlobalDimensionlessSynthesis()
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

# Log/plot every LOG_EVERY noise levels; max() avoids a modulo-by-zero when DENOISE_TIME < 10.
LOG_EVERY = max(1, DENOISE_TIME // 10)
# Run the reverse-sampling preview every SAMPLE_EVERY passes (1 = every pass).
SAMPLE_EVERY = 1

for batch in range(10000):  # one full pass over ds["train"]; `batch` names the saved grids
    for data in ds["train"]:
        image = data["image"]
        # NOTE(review): assumes the torch-formatted image is (C, H, W) with 0..255
        # values — confirm what the datasets formatter returns for this column.
        c, h, w = image.shape
        if h < H or w < W:
            continue  # too small for an H x W crop
        image = (image / 255).unsqueeze(0)

        # randint's upper bound is exclusive. The +1 makes the last valid offset
        # (w - W, h - H) reachable and avoids randint(0, 0) raising when only one
        # side already matches the crop size.
        x = np.random.randint(0, w - W + 1)
        y = np.random.randint(0, h - H + 1)
        img = mt.Crop(size=(H, W), point=(y, x))(image).float()
        noise_list = mt.generate_noise(num=DENOISE_TIME, noise_scale=NOISE_SCALE)

        # NOTE(review): i runs DENOISE_TIME-2 .. 0, so the target level i - 1
        # reaches -1 on the last step, while the sampler below uses levels
        # DENOISE_TIME-1 .. 1. Confirm how mt.noisy_image_generator treats -1.
        for i in reversed(range(0, DENOISE_TIME - 1)):
            optimizer.zero_grad()
            noisy_img, _ = mt.noisy_image_generator(img, noise_list, i)
            noisy_img_t, label_t = mt.noisy_image_generator(img, noise_list, i - 1)

            out_t, mlp, out, noise = model(noisy_img.to(DEVICE), torch.tensor([i], device=DEVICE))

            # Losses are computed on CPU; .cpu() keeps the autograd graph intact.
            out_t, mlp, out, noise = out_t.cpu(), mlp.cpu(), out.cpu(), noise.cpu()
            noisy_img, noisy_img_t = noisy_img.cpu(), noisy_img_t.cpu()
            label_t = label_t.unsqueeze(0).cpu()

            # Named label_term so the module-level label_loss() helper is not overwritten.
            label_term = criterion(label_t.mean(dim=0), mlp.mean(dim=0))
            weight_loss = diversity_loss(model)
            ergas_loss = ergas(out_t, noisy_img_t)
            mse_loss = criterion(out_t, noisy_img_t)
            noise_mse_loss = criterion(out, noisy_img)
            loss_of_loss = (mse_loss - noise_mse_loss) ** 2
            loss_mean = criterion(noise.mean(), mlp[:, 1])
            loss_var = criterion(noise.var(), mlp[:, 2])
            loss = (label_term + mse_loss + noise_mse_loss + loss_of_loss
                    + loss_mean + loss_var + ergas_loss + weight_loss)
            loss.backward()
            optimizer.step()

            if i % LOG_EVERY == 0:
                icecream.ic(label_t, mlp, noise.mean(), noise.var())
                icecream.ic(batch, i, label_term, mse_loss, noise_mse_loss, loss_of_loss,
                            loss_mean, loss_var, ergas_loss, weight_loss)
                show(out_t, noisy_img_t)  # prediction vs. target

        # Snapshot of the last training step for this image. The file name only
        # changes per pass, so each image overwrites the previous snapshot.
        grid = make_grid(torch.cat([out_t.detach(), noisy_img_t], dim=0), nrow=2, normalize=True, padding=2)
        save_image(grid, RESULT_DIR / f"grid_image_batch_{batch}.png")
        save_weights(model, "weights.pth")

        if batch % SAMPLE_EVERY == 0:
            # Reverse-sampling preview starting from the noisiest level. eval()
            # makes BatchNorm use its running statistics instead of updating them
            # with sampled images, which would skew eval-mode inference later.
            model.eval()
            with torch.no_grad():
                noisy_image, _ = mt.noisy_image_generator(img, noise_list, DENOISE_TIME - 1)
                noisy_image = noisy_image.to(DEVICE)
                sample = noisy_image.clone()
                for i in reversed(range(1, DENOISE_TIME)):
                    sample, _, _, _ = model(sample, torch.tensor([i], device=DEVICE))
            model.train()
            # Show the final sample (after all DENOISE_TIME-1 steps) next to its start.
            preview = make_grid(torch.cat([sample.cpu(), noisy_image.cpu()], dim=0),
                                nrow=2, normalize=True, padding=3)
            fig, axs = plt.subplots(ncols=1, squeeze=False)
            axs[0, 0].imshow(tF.to_pil_image(preview))
            plt.show()
            plt.close(fig)

In [None]:
# --- Inference: denoise pure Gaussian noise with the saved weights -----------
model = AffinDiffusion().to(DEVICE)
model = load_weights(model, "weights.pth")
model.eval()

start = torch.randn(1, C, H, W, device=DEVICE)
sample = start.clone()
with torch.no_grad():
    # Same reverse schedule as the training cell's sampler: timesteps
    # DENOISE_TIME-1 down to 1, passed explicitly. The forward() default would
    # otherwise always use timestep 100.
    for i in reversed(range(1, DENOISE_TIME)):
        # forward returns (out_t, mlp, out, noise); out_t is fed back in.
        sample, _, _, _ = model(sample, torch.tensor([i], device=DEVICE))
        if i % 10 == 0:
            show(sample.cpu(), start.cpu())  # current sample vs. starting noise

In [None]:
import torch

# Report whether PyTorch can use CUDA; if it can, also report how many GPUs are
# visible and the name of the currently selected one.
cuda_available = torch.cuda.is_available()
print(f"CUDA利用可能: {cuda_available}")

if not cuda_available:
    print("CUDAが利用できません。CPUで実行されます。")
else:
    gpu_count = torch.cuda.device_count()
    print(f"利用可能なGPUデバイス数: {gpu_count}")

    current_device = torch.cuda.current_device()
    device_name = torch.cuda.get_device_name(current_device)
    print(f"現在のGPUデバイス: {device_name}")
