In [None]:
import torch
import os
from PIL import Image
import clip
import os.path as osp
import os, sys
import torchvision.utils as vutils
sys.path.insert(0, '../')

from lib.utils import load_model_weights,mkdir_p
from models.GALIP import NetG, CLIP_TXT_ENCODER

In [None]:
# Inference device: 'cpu' or a CUDA device string such as 'cuda:0'.
device = 'cpu'
# CLIP backbone name; presumably has to match the backbone the GALIP
# checkpoint below was trained with — TODO confirm.
CLIP_text = "ViT-B/32"
clip_model, preprocess = clip.load(CLIP_text, device=device)
clip_model = clip_model.eval()

In [None]:
# Build the text encoder and the generator on top of the shared CLIP model.
# NOTE(review): NetG's positional args look like
# (ngf, z_dim, cond_dim, imsize, ch_size, clip_model) — TODO confirm against
# models/GALIP.py. The 100 here must equal the noise width used below.
text_encoder = CLIP_TXT_ENCODER(clip_model).to(device)
netG = NetG(64, 100, 512, 256, 3, clip_model).to(device)
path = '../saved_models/bird/state_epoch_150.pth'
# Tensors are loaded onto the CPU first. load_model_weights presumably copies
# them into netG's parameters, which already live on `device` — verify in
# lib/utils. torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load(path, map_location=torch.device('cpu'))
netG = load_model_weights(netG, checkpoint['model']['netG'])
# Put both modules in inference mode (disables train-only behaviour such as
# dropout). Reassigning avoids echoing the module repr in the notebook.
netG = netG.eval()
text_encoder = text_encoder.eval()

In [None]:
# Generate 6 images per prompt, each from its own noise vector.
batch_size = 6
# Fixed seed so the same prompt reproduces the same images across runs.
SEED = 0
# Noise width; must equal the z_dim passed to NetG above (100).
Z_DIM = 100
noise_generator = torch.Generator().manual_seed(SEED)
noise = torch.randn((batch_size, Z_DIM), generator=noise_generator).to(device)

In [None]:
# Define the GUI
from PySide6 import QtWidgets
from PySide6.QtCore import Qt
from PySide6 import QtGui

In [None]:
import sys

In [None]:
class MWindow(QtWidgets.QMainWindow):
    """Frameless main window for interactive text-to-image generation.

    The left pane is a free-text prompt editor and the right pane previews
    the generated image grid. Generation relies on the notebook-level
    globals ``clip``, ``device``, ``text_encoder``, ``netG``, ``noise``,
    ``batch_size`` and ``vutils``.
    """

    # Directory the generated grids are written to (created on demand).
    SAMPLES_DIR = './samples'
    # Maximum length of the prompt-derived part of a file name, keeping
    # names well below common filesystem name-length limits.
    MAX_NAME_LEN = 100

    def __init__(self):
        super().__init__()

        self.resize(600, 400)

        centralWidget = QtWidgets.QWidget(self)
        self.setCentralWidget(centralWidget)
        mainLayout = QtWidgets.QVBoxLayout(centralWidget)

        # Remove the native title bar; a custom one is built below
        self.setWindowFlags(Qt.WindowType.FramelessWindowHint)

        # Custom title bar: title label plus minimize / close buttons
        hlayout = QtWidgets.QHBoxLayout()
        self.title = QtWidgets.QLabel("文本到图像生成器")
        self.minimizeButton = QtWidgets.QPushButton("最小化")
        self.closeButton = QtWidgets.QPushButton("关闭")

        # Connect the title-bar buttons to their slots
        self.minimizeButton.clicked.connect(self.showMinimized)
        self.closeButton.clicked.connect(self.close)

        # Add them to the title-bar layout
        hlayout.addWidget(self.title)
        hlayout.addWidget(self.minimizeButton)
        hlayout.addWidget(self.closeButton)

        # Top half: prompt editor (left) and image preview (right)
        self.topLayout = QtWidgets.QHBoxLayout()
        # Prompt text
        self.textLabel = QtWidgets.QPlainTextEdit(self)
        self.textLabel.setMinimumSize(300, 200)
        self.textLabel.setStyleSheet("border: 1px solid black;")
        # Generated image
        self.imageLabel = QtWidgets.QLabel(self)
        self.imageLabel.setMinimumSize(300, 200)
        self.imageLabel.setStyleSheet("border: 1px solid black;")

        # Add them to the top layout
        self.topLayout.addWidget(self.textLabel)
        self.topLayout.addWidget(self.imageLabel)

        # Bottom half: the generate button
        self.bottomLayout = QtWidgets.QHBoxLayout()
        self.generateButton = QtWidgets.QPushButton(self)
        self.generateButton.setText("✔生成图片")

        # Add it to the bottom layout
        self.bottomLayout.addWidget(self.generateButton)

        # Stack title bar, top half and bottom half vertically
        mainLayout.addLayout(hlayout)
        mainLayout.addLayout(self.topLayout)
        mainLayout.addLayout(self.bottomLayout)

        # Generate when the button is clicked
        self.generateButton.clicked.connect(self.generateImage)

    @classmethod
    def _sample_filename(cls, text):
        """Turn a free-text prompt into a safe, bounded PNG file name.

        Every character that is not alphanumeric, '-' or '_' (path
        separators, dots, newlines, punctuation) becomes '-'. This keeps user
        input from escaping ``SAMPLES_DIR`` or producing an invalid path.
        """
        safe = "".join(c if c.isalnum() or c in "-_" else "-" for c in text)
        return safe[:cls.MAX_NAME_LEN] + ".png"

    def generateImage(self):
        """Generate an image grid for the current prompt and show it.

        An empty or whitespace-only prompt is ignored. The grid is written
        into ``SAMPLES_DIR`` (created if missing) and then loaded into the
        preview label, scaled to fit while keeping its aspect ratio.
        """
        text = self.textLabel.toPlainText().strip()
        if not text:
            return
        # Inference only: skip autograd bookkeeping (saves memory and time).
        with torch.no_grad():
            # truncate=True: long user prompts are cut to CLIP's context
            # length instead of raising.
            tokenized_text = clip.tokenize([text], truncate=True).to(device)
            sent_emb, _ = text_encoder(tokenized_text)
            # One copy of the sentence embedding per noise vector
            sent_emb = sent_emb.repeat(batch_size, 1)
            fake_imgs = netG(noise, sent_emb, eval=True).float()
        # Otherwise the directory is only created by a later notebook cell.
        os.makedirs(self.SAMPLES_DIR, exist_ok=True)
        img_path = os.path.join(self.SAMPLES_DIR, self._sample_filename(text))
        # Same settings as the batch-generation cell: map the value range
        # (-1, 1) to [0, 1] rather than clamping negatives to black.
        vutils.save_image(fake_imgs, img_path, nrow=3,
                          value_range=(-1, 1), normalize=True)
        pic = QtGui.QPixmap(img_path).scaled(
            self.imageLabel.size(),
            Qt.AspectRatioMode.KeepAspectRatio,
            Qt.TransformationMode.SmoothTransformation,
        )
        self.imageLabel.setPixmap(pic)


In [None]:
# Qt allows one QApplication per process; creating a second one (e.g. when
# this cell is re-run in the same kernel) raises RuntimeError, so reuse it.
app = QtWidgets.QApplication.instance()
if app is None:
    app = QtWidgets.QApplication()
Window = MWindow()

In [None]:
Window.show()
# Run the Qt event loop until the window is closed. Not wrapped in
# sys.exit(): in a notebook the resulting SystemExit aborts "Run All",
# so the batch-generation cells below would never execute.
exit_code = app.exec()

In [None]:
# Prompts rendered by the batch (non-GUI) generation cell below.
captions = [
    'a small bird with a dark colored body and a brown head.',
]

In [None]:
# Create the ./samples output directory used by the generation cell below.
# NOTE(review): mkdir_p presumably behaves like `mkdir -p` (no error if the
# directory already exists) — verify in lib/utils.
mkdir_p('./samples')

In [None]:
# Generate from text: one image grid per caption, saved under ./samples.
with torch.no_grad():
    for caption in captions:
        tokenized_text = clip.tokenize([caption]).to(device)
        sent_emb, _ = text_encoder(tokenized_text)
        # One copy of the sentence embedding per noise vector
        sent_emb = sent_emb.repeat(batch_size, 1)
        fake_imgs = netG(noise, sent_emb, eval=True).float()
        print(fake_imgs.shape)
        # File name: spaces -> '-', path separators neutralised, and a
        # trailing period dropped only if there actually is one (the old
        # slice always removed the last character).
        name = caption.replace(" ", "-").replace("/", "-").replace("\\", "-")
        if name.endswith('.'):
            name = name[:-1]
        # save_image maps the value range (-1, 1) to [0, 1] before writing.
        vutils.save_image(fake_imgs, './samples/%s.png' % (name), nrow=8,
                          value_range=(-1, 1), normalize=True)