<a href="https://colab.research.google.com/github/frzlh/DeepLearning/blob/main/Cnnhuigui.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Download and unpack the LFW deep-funneled archive, then split its top-level
# entries into train (names starting with A-W) and test (names starting with X-Z).
# -nc: skip the download when the archive is already on disk, so re-running
# this cell does not accumulate lfw-deepfunneled.tgz.1, .2, ... copies.
!wget -nc http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz
!tar xf lfw-deepfunneled.tgz
# -p: do not fail when the split directories already exist from an earlier run.
!mkdir -p lfw-deepfunneled/train
!mkdir -p lfw-deepfunneled/test
!mv lfw-deepfunneled/[A-W]* lfw-deepfunneled/train
!mv lfw-deepfunneled/[X-Z]* lfw-deepfunneled/test

--2024-08-20 14:34:31--  http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz
Resolving vis-www.cs.umass.edu (vis-www.cs.umass.edu)... 128.119.244.95
Connecting to vis-www.cs.umass.edu (vis-www.cs.umass.edu)|128.119.244.95|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 108761145 (104M) [application/x-gzip]
Saving to: ‘lfw-deepfunneled.tgz.3’


In [None]:
import torch
from torch import nn,optim
from torchvision.datasets import ImageFolder
import tqdm
from torchvision import transforms
from torch.utils.data import(Dataset,DataLoader,TensorDataset)
class DownSizedPairImageFolder(ImageFolder):
  """ImageFolder that yields a (small, large) resized pair of each image.

  The class label stored by ImageFolder is ignored; every item is the same
  source image passed through two different ``transforms.Resize`` instances.
  """

  def __init__(self,root,transform=None,large_size=128,small_size=32,**kwds):
    """Set up the folder dataset and the two resizers.

    Args:
      root: directory handed to ImageFolder.
      transform: optional callable applied to both resized images.
      large_size: size argument for the target (large) ``transforms.Resize``.
      small_size: size argument for the input (small) ``transforms.Resize``.
      **kwds: forwarded unchanged to ImageFolder.
    """
    super().__init__(root,transform=transform,**kwds)
    self.large_resizer=transforms.Resize(large_size)
    self.small_resizer=transforms.Resize(small_size)

  def __getitem__(self, index):
    """Return ``(small_img, large_img)`` for the image at ``index``."""
    # self.imgs holds (path, class_index) tuples; only the path is needed.
    source=self.loader(self.imgs[index][0])

    big=self.large_resizer(source)
    little=self.small_resizer(source)

    if self.transform is None:
      return little,big

    # Large image is transformed first, then the small one.
    big=self.transform(big)
    little=self.transform(little)
    return little,big

创建用于训练和验证的数据集与数据加载器

In [None]:
# Datasets of (32px input, 128px target) image pairs, converted to tensors.
train_data=DownSizedPairImageFolder("lfw-deepfunneled/train",transform=transforms.ToTensor())
test_data=DownSizedPairImageFolder("lfw-deepfunneled/test",transform=transforms.ToTensor())

batch_size=32
# Shuffle only the training data; the evaluation loader is read in a fixed
# order because the validation score does not depend on sample order.
train_loader=DataLoader(train_data,batch_size=batch_size,shuffle=True,num_workers=2)
test_loader=DataLoader(test_data,batch_size=batch_size,shuffle=False,num_workers=2)

模型创建

In [None]:
# Upscaling CNN. Every layer uses kernel 4, stride 2, padding 1, so each
# Conv2d halves and each ConvTranspose2d doubles an even spatial size:
# two halvings followed by four doublings give a 4x larger output.
# Each stage is (conv layer, ReLU, BatchNorm2d); the final transposed
# convolution maps back to 3 channels with no activation after it.
_stage_specs=[
    (nn.Conv2d,3,256),
    (nn.Conv2d,256,512),
    (nn.ConvTranspose2d,512,256),
    (nn.ConvTranspose2d,256,128),
    (nn.ConvTranspose2d,128,64),
]
_layers=[]
for _conv_cls,_c_in,_c_out in _stage_specs:
  _layers+=[_conv_cls(_c_in,_c_out,4,stride=2,padding=1),nn.ReLU(),nn.BatchNorm2d(_c_out)]
_layers.append(nn.ConvTranspose2d(64,3,4,stride=2,padding=1))

net=nn.Sequential(*_layers)

验证用辅助函数

In [None]:
import math
def psnr(mse,max_v=1.0):
  """Convert a mean squared error into peak signal-to-noise ratio in dB.

  Args:
    mse: mean squared error between two signals; must be non-negative.
    max_v: maximum possible signal value (1.0 by default).

  Returns:
    ``10*log10(max_v**2/mse)``, or ``inf`` when ``mse`` is exactly zero
    (identical signals) instead of raising ZeroDivisionError.

  Raises:
    ValueError: if ``mse`` is negative.
  """
  if mse<0:
    raise ValueError("mse must be non-negative")
  if mse==0:
    # A perfect reconstruction has unbounded PSNR.
    return math.inf
  return 10*math.log10(max_v**2/mse)

def eval_net(net,data_loader,device="cpu"):
  """Compute the mean squared error of ``net`` over everything in ``data_loader``.

  The squared error is accumulated batch by batch instead of concatenating
  every target and prediction first, so memory use on ``device`` does not
  grow with the size of the dataset. The result is the same element-wise
  mean that ``mse_loss`` would give over the concatenated tensors.

  Args:
    net: model to evaluate; it is switched to eval mode.
    data_loader: iterable of ``(x, y)`` tensor batches.
    device: device the batches are moved to before the forward pass.

  Returns:
    The MSE over all elements as a Python float.

  Raises:
    ValueError: if ``data_loader`` yields no elements.
  """
  net.eval()
  total_sq_err=0.0
  n_elems=0
  with torch.no_grad():
    for x,y in data_loader:
      x=x.to(device)
      y=y.to(device)
      y_pred=net(x)
      # Sum (not mean) so that batches of different sizes are weighted correctly.
      total_sq_err+=nn.functional.mse_loss(y_pred,y,reduction="sum").item()
      n_elems+=y.numel()
  if n_elems==0:
    raise ValueError("data_loader yielded no samples to evaluate")
  return total_sq_err/n_elems

训练用辅助函数

In [None]:
def train_net(net,train_loader,test_loader,optimizer_cls=optim.Adam,loss_fn=nn.MSELoss(),n_iter=10,device="cpu"):
  """Train ``net`` and print loss and PSNR figures after every epoch.

  Args:
    net: model to optimise; its parameters are updated in place.
    train_loader: iterable of ``(x, y)`` training batches; must support ``len()``.
    test_loader: batches passed to ``eval_net`` after each epoch.
    optimizer_cls: optimizer class, constructed with default hyper-parameters.
    loss_fn: loss applied to ``(prediction, target)``.
    n_iter: number of epochs.
    device: device the batches are moved to.

  Returns:
    None. Each epoch prints: epoch index, mean training loss per batch,
    PSNR of that training loss, PSNR of the validation MSE.
  """
  train_losses=[]
  val_losses=[]
  optimizer=optimizer_cls(net.parameters())
  for epoch in range(n_iter):
    running_loss=0.0
    # eval_net leaves the model in eval mode, so re-enable training each epoch.
    net.train()
    for xx,yy in tqdm.tqdm(train_loader,total=len(train_loader)):
      xx=xx.to(device)
      yy=yy.to(device)
      y_pred=net(xx)
      loss=loss_fn(y_pred,yy)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      running_loss+=loss.item()
    # Average of the per-batch losses for this epoch.
    train_losses.append(running_loss/len(train_loader))
    val_losses.append(eval_net(net,test_loader,device))
    # psnr() of the training loss is only meaningful when loss_fn is an MSE.
    print(epoch,train_losses[-1],psnr(train_losses[-1]),psnr(val_losses[-1]),flush=True)

In [None]:
# Train on the first GPU when one is available; otherwise fall back to the
# CPU instead of crashing on a runtime without CUDA.
device="cuda:0" if torch.cuda.is_available() else "cpu"
net.to(device)
train_net(net,train_loader,test_loader,device=device)

100%|██████████| 409/409 [00:29<00:00, 14.03it/s]


0 0.04405897735756956 13.55965587374056 19.099055017634665


100%|██████████| 409/409 [00:29<00:00, 13.82it/s]


1 0.006398116668781908 21.939478448792606 21.658744664348223


100%|██████████| 409/409 [00:29<00:00, 13.97it/s]


2 0.004811940648112546 23.1767973801469 24.0151132339517


100%|██████████| 409/409 [00:30<00:00, 13.35it/s]


3 0.004136622879394583 23.833540703211643 21.02695702641526


100%|██████████| 409/409 [00:29<00:00, 13.93it/s]


4 0.00353722103238379 24.51337801325947 24.03386183752485


100%|██████████| 409/409 [00:29<00:00, 13.82it/s]


5 0.0033803877470819027 24.710334811543788 19.481594747404433


100%|██████████| 409/409 [00:29<00:00, 13.95it/s]


6 0.0031685953291501934 24.99133222324097 25.49348215655359


100%|██████████| 409/409 [00:28<00:00, 14.16it/s]


7 0.002957441344228744 25.290838600999987 26.040780039094837


100%|██████████| 409/409 [00:29<00:00, 14.04it/s]


8 0.0029992509373147178 25.229871967506654 25.87195847632593


 51%|█████     | 209/409 [00:15<00:13, 14.97it/s]

放大图像与原始图像的比较

In [37]:
from torchvision.utils import save_image
# Draw four random test pairs: x is the low-resolution input, y the target.
random_test_loader=DataLoader(test_data,batch_size=4,shuffle=True)
it=iter(random_test_loader)
x,y=next(it)

# Baseline: plain bilinear upscaling of the input to 128 pixels.
bl_recon=torch.nn.functional.interpolate(x,128,mode="bilinear",align_corners=True)
# Run the model in eval mode (BatchNorm uses its running statistics) and
# without gradient tracking, on whichever device the model currently lives on.
net.eval()
net_device=next(net.parameters()).device
with torch.no_grad():
  yp=net(x.to(net_device)).to("cpu")
# Image order in the grid: originals, bilinear baseline, CNN output (4 per row).
save_image(torch.cat([y,bl_recon,yp],0),"cnn_upscale.jpg",nrow=4)

In [None]:
from IPython.display import Image,display_jpeg
# Show the saved comparison grid inline in the notebook.
upscale_preview=Image('cnn_upscale.jpg')
display_jpeg(upscale_preview)