In [3]:
import torch
import argparse
import numpy as np
from modules.tokenizers import Tokenizer
from modules.dataloaders import R2DataLoader
from modules.metrics import compute_scores
from modules.tester import Tester
from modules.loss import compute_loss
from models.r2gen import R2GenModel
from dataclasses import dataclass
from typing import Optional
from PIL import Image
from torchvision import transforms
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
@dataclass
class Config:
    """配置参数类，对应parse_args()函数的所有参数"""
    # Data input settings
    image_dir: str = 'data/iu_xray/images/'
    ann_path: str = 'data/iu_xray/annotation.json'
    
    # Data loader settings
    dataset_name: str = 'iu_xray'
    max_seq_length: int = 60
    threshold: int = 3
    num_workers: int = 2
    batch_size: int = 16
    
    # Model settings (for visual extractor)
    visual_extractor: str = 'resnet101'
    visual_extractor_pretrained: bool = True
    
    # Model settings (for Transformer)
    d_model: int = 512
    d_ff: int = 512
    d_vf: int = 2048
    num_heads: int = 8
    num_layers: int = 3
    dropout: float = 0.1
    logit_layers: int = 1
    bos_idx: int = 0
    eos_idx: int = 0
    pad_idx: int = 0
    use_bn: int = 0
    drop_prob_lm: float = 0.5
    
    # for Relational Memory
    rm_num_slots: int = 3
    rm_num_heads: int = 8
    rm_d_model: int = 512
    
    # Sample related
    sample_method: str = 'beam_search'
    beam_size: int = 3
    temperature: float = 1.0
    sample_n: int = 1
    group_size: int = 1
    output_logsoftmax: int = 1
    decoding_constraint: int = 0
    block_trigrams: int = 1
    
    # Trainer settings
    n_gpu: int = 1
    epochs: int = 100
    save_dir: str = 'results/iu_xray'
    record_dir: str = 'records/'
    save_period: int = 1
    monitor_mode: str = 'max'
    monitor_metric: str = 'BLEU_4'
    early_stop: int = 50
    
    # Optimization
    optim: str = 'Adam'
    lr_ve: float = 5e-5
    lr_ed: float = 1e-4
    weight_decay: float = 5e-5
    amsgrad: bool = True
    
    # Learning Rate Scheduler
    lr_scheduler: str = 'StepLR'
    step_size: int = 50
    gamma: float = 0.1
    
    # Others
    seed: int = 9233
    resume: Optional[str] = None
    load: Optional[str] = "data/model_iu_xray.pth"

In [8]:
args = Config()

In [13]:
a = torch.load("data/model_iu_xray.pth")
tokenizer = Tokenizer(args)
model = R2GenModel(args, tokenizer)
model.load_state_dict(a['state_dict'])
model.to(torch.device('cuda:0'))



R2GenModel(
  (visual_extractor): VisualExtractor(
    (model): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplac

In [22]:
image_1 = Image.open("data/iu_xray/images/CXR30_IM-1385/0.png").convert('RGB')
image_2 = Image.open("data/iu_xray/images/CXR30_IM-1385/1.png").convert('RGB')
transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))])
image_1 = transform(image_1)
image_2 = transform(image_2)
image = torch.stack((image_1, image_2), 0)
image = image.unsqueeze(0)
image = image.to(torch.device('cuda:0'))

In [23]:
output = model(image, mode='sample')
reports = model.tokenizer.decode_batch(output.cpu().numpy())
reports

['the cardiomediastinal silhouette is normal in size and contour . no focal consolidation pneumothorax or large pleural effusion . negative for acute bone abnormality .']

In [24]:
model_t = ChatOpenAI(
    model = "deepseek-chat",
    temperature = 0.3,
    api_key = "sk-c8f09fecf95a49f2b7b1456e7fb5f3e9",
    base_url = "https://api.deepseek.com"
)
chat_template = ChatPromptTemplate(
    [
        ("system","你是一个翻译员"),
        ("human","翻译信息如下：{text}")
    ]
)
t_chain = chat_template | model_t
res = t_chain.invoke(reports).content
res

'心纵隔影大小及轮廓正常。未见局灶性实变、气胸或大量胸腔积液。急性骨性病变阴性。'