diff --git a/README.md b/README.md index 8f8901c5..e1c3bab0 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ Fundamental research to develop new architectures for foundation models and A(G) ## News +- December, 2023: [LongNet](torchscale/model/LongNet.py) and [LongViT](examples/longvit/README.md) released - October, 2023: Update RMSNorm and SwiGLU as the default module in RetNet - November, 2022: TorchScale 0.1.1 released [[Paper](https://arxiv.org/abs/2211.13184)] [[PyPI](https://pypi.org/project/torchscale/)] @@ -37,6 +38,18 @@ cd torchscale pip install -e . ``` +For faster training install [Flash Attention](https://github.com/Dao-AILab/flash-attention) for Turing, Ampere, Ada, or Hopper GPUs: +``` +pip install flash-attn +``` +or [xFormers](https://github.com/facebookresearch/xformers) for Volta, Turing, Ampere, Ada, or Hopper GPUs: +``` +# cuda 11.8 version +pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118 +# cuda 12.1 version +pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121 +``` + ## Getting Started It takes only several lines of code to create a model with the above fundamental research features enabled. Here is how to quickly obtain a BERT-like encoder: @@ -85,6 +98,21 @@ It takes only several lines of code to create a RetNet model: >>> print(retnet) ``` +For LongNet models ([Flash Attention](https://github.com/Dao-AILab/flash-attention) required): +```python +>>> import torch +>>> from torchscale.architecture.config import EncoderConfig, DecoderConfig +>>> from torchscale.model.longnet import LongNetEncoder, LongNetDecoder + +# Creating a LongNet encoder with the dilated pattern of segment_length=[2048,4096] and dilated_ratio=[1,2] +>>> config = EncoderConfig(vocab_size=64000, segment_length='[2048,4096]', dilated_ratio='[1,2]', flash_attention=True) +>>> longnet = LongNetEncoder(config) + +# Creating a LongNet decoder with the dilated pattern of segment_length=[2048,4096] and dilated_ratio=[1,2] +>>> config = DecoderConfig(vocab_size=64000, segment_length='[2048,4096]', dilated_ratio='[1,2]', flash_attention=True) +>>> longnet = LongNetDecoder(config) +``` + ## Key Features - [DeepNorm to improve the training stability of Post-LayerNorm Transformers](https://arxiv.org/abs/2203.00555) @@ -142,6 +170,8 @@ We have examples of how to use TorchScale in the following scenarios/tasks: - Vision + * [LongViT](examples/longvit/README.md) + * ViT/BEiT [In progress] - Speech @@ -228,6 +258,26 @@ If you find this repository useful, please consider citing our work: } ``` +``` +@article{longnet, + author={Jiayu Ding and Shuming Ma and Li Dong and Xingxing Zhang and Shaohan Huang and Wenhui Wang and Nanning Zheng and Furu Wei}, + title = {{LongNet}: Scaling Transformers to 1,000,000,000 Tokens}, + journal = {ArXiv}, + volume = {abs/2307.02486}, + year = {2023} +} +``` + +``` +@article{longvit, + title = {When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology}, + author = {Wenhui Wang and Shuming Ma and Hanwen Xu and Naoto Usuyama and Jiayu Ding and Hoifung Poon and Furu Wei}, + journal = {ArXiv}, + volume = {abs/2312.03558}, + year = {2023} +} +``` + ## Contributing This project welcomes contributions and suggestions. 
Most contributions require you to agree to a diff --git a/examples/fairseq/README.md b/examples/fairseq/README.md index bb8a71e3..b7b682ff 100644 --- a/examples/fairseq/README.md +++ b/examples/fairseq/README.md @@ -251,6 +251,45 @@ python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \ --use-xmoe ``` +### LongNet Model + +```bash +cd examples/fairseq/ +python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \ + ${PATH_TO_DATA} \ + --num-workers 2 \ + --activation-fn gelu \ + --share-decoder-input-output-embed \ + --validate-interval-updates 1000 \ + --save-interval-updates 1000 \ + --no-epoch-checkpoints \ + --memory-efficient-fp16 \ + --fp16-init-scale 4 \ + --arch lm_base \ + --task language_modeling \ + --sample-break-mode none \ + --tokens-per-sample 4096 \ + --optimizer adam --adam-betas "(0.9, 0.98)" \ + --adam-eps 1e-08 \ + --clip-norm 0.0 \ + --lr 5e-4 \ + --lr-scheduler polynomial_decay \ + --warmup-updates 750 \ + --dropout 0.1 \ + --attention-dropout 0.1 \ + --weight-decay 0.01 \ + --batch-size 4 \ + --update-freq 1 \ + --required-batch-size-multiple 1 \ + --total-num-update 50000 \ + --max-update 50000 \ + --seed 1 \ + --ddp-backend=c10d \ + --flash-attention \ + --segment-length [2048,4096] \ + --dilated-ratio [1,2] +``` + ## Example: Machine Translation ### Data Format diff --git a/examples/fairseq/models/language_modeling.py b/examples/fairseq/models/language_modeling.py index 71bf1a52..38cf9e2b 100644 --- a/examples/fairseq/models/language_modeling.py +++ b/examples/fairseq/models/language_modeling.py @@ -25,6 +25,7 @@ from torchscale.architecture.config import DecoderConfig from torchscale.architecture.decoder import Decoder +from torchscale.model.LongNet import LongNetDecoder DEFAULT_MAX_TARGET_POSITIONS = 1024 logger = logging.getLogger(__name__) @@ -196,6 +197,19 @@ class LanguageConfig(FairseqDataclass): xpos_scale_base: Optional[int] = field( default=512, ) + flash_attention: Optional[bool] = field( + default=False, + ) + seq_parallel: Optional[bool] = field( + default=False, + ) + segment_length: Optional[str] = field( + default='', + ) + dilated_ratio: Optional[str] = field( + default='', + ) + @register_model("lm", dataclass=LanguageConfig) @@ -256,7 +270,13 @@ def build_model(cls, args, task): config = DecoderConfig() config.override(args) - decoder = LMDecoder( + if args.segment_length != '': + assert args.dilated_ratio != '' + DECODER_CLASS = LongNetLMDecoder + else: + DECODER_CLASS = LMDecoder + + decoder = DECODER_CLASS( config, embed_tokens, embed_positions, @@ -291,6 +311,25 @@ def reorder_incremental_state_scripting( incremental_state[module][key] = result +class LongNetLMDecoder(LongNetDecoder, FairseqIncrementalDecoder): + def forward(self, src_tokens, **kwargs): + self_attn_padding_mask = src_tokens.eq(self.dictionary.pad()) + return super().forward(src_tokens, self_attn_padding_mask, **kwargs) + + def max_positions(self): + return self.embed_positions.max_positions + + def reorder_incremental_state_scripting( + self, + incremental_state, + new_order, + ): + for module in incremental_state: + for key in incremental_state[module]: + result = incremental_state[module][key].index_select(0, new_order) + incremental_state[module][key] = result + + @register_model_architecture("lm", "lm_base") def base_lm_architecture(args): # backward compatibility for older model checkpoints diff --git a/examples/longvit/README.md b/examples/longvit/README.md new file mode 100644 index 00000000..9e977c8d --- /dev/null +++ 
b/examples/longvit/README.md @@ -0,0 +1,71 @@ +# [(LongViT) When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology](https://arxiv.org/abs/2312.03558) + +**LongViT** is a vision Transformer that can process gigapixel images (e.g., 32,768x32,768 images) in an end-to-end manner. We split the image into millions of patches and employ [LongNet](https://arxiv.org/abs/2307.02486) to directly model the extremely long sequence. We apply LongViT in the field of computational pathology and achieve remarkable performance on cancer subtyping and survival prediction tasks. + + +## Setup +``` +pip install -r requirements.txt +pip install git+https://github.com/shumingma/fairseq.git@moe +pip install -v -U git+https://github.com/facebookresearch/xformers.git@v0.0.20#egg=xformers +``` + + +## Pretraining + +We perform self-supervised pretraining on TCGA diagnostic slides using [DINO](https://arxiv.org/abs/2104.14294) objective. The detailed instructions can be found at [`get_started_for_tcga_pretraining.md`](get_started/get_started_for_tcga_pretraining.md). + +The link to the pretrained LongViT model on TCGA diagnostic slides: + - [`LongViT`](https://conversationhub.blob.core.windows.net/beit-share-public/longvit/longvit_small_patch32_1024.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D): #layer=12; hidden=384; FFN factor=4x; #head=16; patch=32x32 + + +## Fine-tuning on Subtyping Classification + +We perform finetuning on cancer subtyping on images with sizes up to 32,768x32,768 (1M patches). The detailed instructions can be found at [`get_started_for_tcga_subtyping.md`](get_started/get_started_for_tcga_subtyping.md). + + +## Fine-tuning on Survival Prediction + +We perform finetuning on survival prediction on images with sizes up to 32,768x32,768 (1M patches). The detailed instructions can be found at [`get_started_for_tcga_survival_prediction.md`](get_started/get_started_for_tcga_survival_prediction.md). + + +## Citation + +If you find this repository useful, please consider citing our work: +``` +@article{longvit, + title={When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology}, + author={Wang, Wenhui and Ma, Shuming and Xu, Hanwen and Usuyama, Naoto and Ding, Jiayu and Poon, Hoifung and Wei, Furu}, + journal={arXiv preprint arXiv:2312.03558}, + year={2023} +} + +@article{longnet, + title={LongNet: Scaling transformers to 1,000,000,000 tokens}, + author={Ding, Jiayu and Ma, Shuming and Dong, Li and Zhang, Xingxing and Huang, Shaohan and Wang, Wenhui and Zheng, Nanning and Wei, Furu}, + journal={arXiv preprint arXiv:2307.02486}, + year={2023} +} + +@article{torchscale, + title={TorchScale: Transformers at scale}, + author={Ma, Shuming and Wang, Hongyu and Huang, Shaohan and Wang, Wenhui and Chi, Zewen and Dong, Li and Benhaim, Alon and Patra, Barun and Chaudhary, Vishrav and Song, Xia and others}, + journal={arXiv preprint arXiv:2211.13184}, + year={2022} +} +``` + + +## Acknowledgement + +This repository is built using the [BEiT-3](https://github.com/microsoft/unilm/tree/master/beit3), the [MCAT](https://github.com/mahmoodlab/MCAT), the [DINO](https://github.com/facebookresearch/dino), the [HIPT](https://github.com/mahmoodlab/HIPT) repository and the [timm](https://github.com/rwightman/pytorch-image-models) library. + + +## License +This project is licensed under the license found in the LICENSE file in the root directory of this source tree. 
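As a rough sanity check on the sequence lengths involved (assuming the 32x32 patch size of the released `longvit_small_patch32` checkpoint), a 32,768x32,768 slide becomes roughly one million patch tokens, which LongNet models as a single sequence:

```python
# Back-of-the-envelope patch count for the largest image size used in the paper.
# Assumes the 32x32 patch size of the released checkpoint; illustration only, not a repo API.
image_size, patch_size = 32_768, 32
patches_per_side = image_size // patch_size   # 1,024 ("an image is worth 1,024 x 1,024 words")
print(patches_per_side ** 2)                  # 1,048,576 patch tokens per slide
```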
+ +[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct) + +### Contact Information + +For help or issues using LongViT models, please submit a GitHub issue. diff --git a/examples/longvit/data_preprocessing/cache_transformed_images.py b/examples/longvit/data_preprocessing/cache_transformed_images.py new file mode 100644 index 00000000..9a211bcf --- /dev/null +++ b/examples/longvit/data_preprocessing/cache_transformed_images.py @@ -0,0 +1,78 @@ +import os +import sys +import torch +import random +import argparse +from PIL import Image, ImageFilter, ImageOps +from multiprocessing import Pool, cpu_count +from timm.data.transforms import RandomResizedCropAndInterpolation +import torchvision.transforms as transforms + +Image.MAX_IMAGE_PIXELS = 6400000000 + + +def build_transform(input_size): + train_interpolation = "bicubic" + t = [ + RandomResizedCropAndInterpolation(input_size, scale=(0.5, 1.0), interpolation=train_interpolation), + transforms.RandomHorizontalFlip(), + ] + t = transforms.Compose(t) + + return t + + +def pil_loader(path): + with open(path, "rb") as f: + img = Image.open(f) + return img.convert("RGB") + + +def save_image(transformed_img, output_image_path): + if isinstance(transformed_img, torch.Tensor): + transformed_img = transforms.ToPILImage()(transformed_img) + transformed_img.save(output_image_path) + + +def get_image_files(input_dir): + for root, _, files in os.walk(input_dir): + for file in files: + if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')): + yield os.path.join(root, file) + + +def transform_and_save_crops(args): + input_path, input_dir, output_dir, transform = args + print(input_path) + file_basename = os.path.basename(input_path) + + img = pil_loader(input_path) + transformed_img = transform(img) + output_image_path = os.path.join(output_dir, file_basename) + save_image(transformed_img, output_image_path) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Save transformed images in a directory.') + parser.add_argument('input_dir', help='Path to the input directory.') + parser.add_argument('output_dir', help='Path to the output directory.') + parser.add_argument('-p', '--processes', type=int, default=cpu_count(), help='Number of processes to use. 
Default: number of CPU cores') + parser.add_argument('--input_size', type=int, default=16384, help='input image size') + args = parser.parse_args() + + input_dir = args.input_dir + output_dir = args.output_dir + num_processes = args.processes + input_size = args.input_size + print("num_processes: {}".format(num_processes)) + print("input_size: {}".format(input_size)) + + transform = build_transform(input_size=input_size) + + image_files = list(get_image_files(input_dir)) + task_args = [(file, input_dir, output_dir, transform) for file in image_files] + + os.makedirs(output_dir, exist_ok=True) + + with Pool(processes=num_processes) as pool: + pool.map(transform_and_save_crops, task_args) diff --git a/examples/longvit/data_preprocessing/convert_wsi_to_images.py b/examples/longvit/data_preprocessing/convert_wsi_to_images.py new file mode 100644 index 00000000..53663aa0 --- /dev/null +++ b/examples/longvit/data_preprocessing/convert_wsi_to_images.py @@ -0,0 +1,45 @@ +import os +import glob +import argparse +import openslide + +from PIL import Image +from concurrent.futures import ProcessPoolExecutor + + +def convert_wsi_to_images(slide_path, image_path, target_size, level=0): + slide = openslide.open_slide(slide_path) + level_dims = slide.level_dimensions + region = slide.read_region((0,0), level, level_dims[level]) + region = region.convert("RGB") + print("convert: {}({}) -> {}".format(slide_path, region.size, image_path)) + resized_img = region.resize((target_size, target_size), Image.BICUBIC) + resized_img.save(image_path) + + +def process_slides(input_folder, output_folder, target_size, level=0): + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + slide_paths = glob.glob(os.path.join(input_folder, "*.svs")) + + with ProcessPoolExecutor(max_workers=1) as executor: + for slide_path in slide_paths: + image_path = os.path.join(output_folder, os.path.basename(slide_path).split(".svs")[0] + ".jpg") + executor.submit(convert_wsi_to_images, slide_path, image_path, target_size, level=level) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert slides into images") + parser.add_argument("input_folder", type=str, help="") + parser.add_argument("output_folder", type=str, help="") + parser.add_argument("target_size", type=int, help="") + parser.add_argument("level", type=int, help="") + + args = parser.parse_args() + input_folder = args.input_folder + output_folder = args.output_folder + target_size = args.target_size + level = args.level + + process_slides(input_folder, output_folder, target_size, level=level) diff --git a/examples/longvit/data_preprocessing/create_tcga_subtyping_index.py b/examples/longvit/data_preprocessing/create_tcga_subtyping_index.py new file mode 100644 index 00000000..be962d8f --- /dev/null +++ b/examples/longvit/data_preprocessing/create_tcga_subtyping_index.py @@ -0,0 +1,37 @@ +from datasets import TCGASubtypingDataset + +tcga_task = "tcga_brca" +for k_fold in range(10): + TCGASubtypingDataset.make_dataset_index( + task=tcga_task, + csv_path="./subtyping_dataset_csv/{}_subset.csv.zip".format(tcga_task), + csv_split_path="./subtyping_splits/10foldcv_subtype/{}/splits_{}.csv".format(tcga_task, k_fold), + k_fold=k_fold, + index_path="./subtyping_split_index/{}".format(tcga_task), + ignore=['MDLC', 'PD', 'ACBC', 'IMMC', 'BRCNOS', 'BRCA', 'SPC', 'MBC', 'MPT'], + label_dict = {'IDC':0, 'ILC':1}, + ) + +tcga_task = "tcga_lung" +for k_fold in range(10): + TCGASubtypingDataset.make_dataset_index( + task=tcga_task, + 
csv_path="./subtyping_dataset_csv/{}_subset.csv.zip".format(tcga_task), + csv_split_path="./subtyping_splits/10foldcv_subtype/{}/splits_{}.csv".format(tcga_task, k_fold), + k_fold=k_fold, + index_path="./subtyping_split_index/{}".format(tcga_task), + ignore=[], + label_dict = {'LUAD':0, 'LUSC':1}, + ) + +tcga_task = "tcga_kidney" +for k_fold in range(10): + TCGASubtypingDataset.make_dataset_index( + task=tcga_task, + csv_path="./subtyping_dataset_csv/{}_subset.csv.zip".format(tcga_task), + csv_split_path="./subtyping_splits/10foldcv_subtype/{}/splits_{}.csv".format(tcga_task, k_fold), + k_fold=k_fold, + index_path="./subtyping_split_index/{}".format(tcga_task), + ignore=[], + label_dict = {'CCRCC':0, 'PRCC':1, 'CHRCC':2}, + ) diff --git a/examples/longvit/data_preprocessing/create_tcga_survival_index.py b/examples/longvit/data_preprocessing/create_tcga_survival_index.py new file mode 100644 index 00000000..77698169 --- /dev/null +++ b/examples/longvit/data_preprocessing/create_tcga_survival_index.py @@ -0,0 +1,11 @@ +from datasets import TCGASurvivalDataset + +for tcga_task in ["tcga_ucec", "tcga_luad", "tcga_brca"]: + for k_fold in range(5): + TCGASurvivalDataset.make_dataset_index( + task=tcga_task, + csv_path="./survival_dataset_csv/{}_all_clean.csv.zip".format(tcga_task), + csv_split_path="./survival_splits/5foldcv/{}/splits_{}.csv".format(tcga_task, k_fold), + k_fold=k_fold, + index_path="./survival_split_index/{}".format(tcga_task), + ) \ No newline at end of file diff --git a/examples/longvit/data_preprocessing/generate_1024_crops.py b/examples/longvit/data_preprocessing/generate_1024_crops.py new file mode 100644 index 00000000..26de7efb --- /dev/null +++ b/examples/longvit/data_preprocessing/generate_1024_crops.py @@ -0,0 +1,99 @@ +import os +import sys +import cv2 +import json +import numpy as np +import openslide +import time +import torch +import openslide +import argparse +import random +import shutil + +import glob +from concurrent.futures import ProcessPoolExecutor + +from PIL import Image +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD + + +def is_similar_pixel(pixel1, pixel2, threshold=30): + return np.linalg.norm(pixel1 - pixel2) < threshold + + +def should_discard_image(image_path, target_pixel=np.array([243, 243, 243]), threshold=30, similarity_ratio=0.99): + image = cv2.imread(image_path) + height, width, _ = image.shape + + similar_pixels = 0 + total_pixels = height * width + + for y in range(height): + for x in range(width): + pixel = image[y, x] + + if is_similar_pixel(pixel, target_pixel, threshold): + similar_pixels += 1 + + ratio = similar_pixels / total_pixels + return ratio > similarity_ratio + + +def random_crop(slide_path, output_path, min_crop_size, max_crop_size, level=0): + slide = openslide.open_slide(slide_path) + level_dim = slide.level_dimensions + slide_width, slide_height = slide.dimensions + + crop_width = random.randint(min_crop_size, max_crop_size) + crop_height = random.randint(min_crop_size, max_crop_size) + + x = random.randint(0, slide_width - crop_width) + y = random.randint(0, slide_height - crop_height) + + region = slide.read_region((x, y), level, (crop_width, crop_height)) + region = region.convert("RGB") + region.save(output_path) + + +def get_crops(slide_path, output_folder, crop_number, min_crop_size, max_crop_size): + print(slide_path) + + index = 0 + while index < 
crop_number: + output_path = os.path.join(output_folder, os.path.basename(slide_path).split(".svs")[0], f"{str(index).zfill(8)}.JPEG") + + dir_path = os.path.dirname(output_path) + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + random_crop(slide_path, output_path, min_crop_size, max_crop_size) + if not should_discard_image(output_path): + index += 1 + + +def process_slides(input_folder, output_folder, crop_number=100, min_crop_size=1024, max_crop_size=1536): + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + slide_paths = glob.glob(f"{input_folder}/**/*.svs", recursive=True) + + with ProcessPoolExecutor(max_workers=4) as executor: + for slide_path in slide_paths: + executor.submit(get_crops, slide_path, output_folder, crop_number, min_crop_size, max_crop_size) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate crops from slides") + parser.add_argument("input_folder", type=str, help="") + parser.add_argument("output_folder", type=str, help="") + parser.add_argument("crop_number", type=int, help="") + + args = parser.parse_args() + input_folder = args.input_folder + output_folder = args.output_folder + crop_number = args.crop_number + + process_slides(input_folder, output_folder, crop_number=crop_number) diff --git a/examples/longvit/data_preprocessing/split_to_small_images.py b/examples/longvit/data_preprocessing/split_to_small_images.py new file mode 100644 index 00000000..8bc615a0 --- /dev/null +++ b/examples/longvit/data_preprocessing/split_to_small_images.py @@ -0,0 +1,67 @@ +import os +import json +import shutil +import argparse +from PIL import Image +from concurrent.futures import ProcessPoolExecutor + +Image.MAX_IMAGE_PIXELS = 6400000000 + + +def split_image(image_path, input_folder, output_folder, num_splits): + print(image_path) + file_name, file_ext = os.path.splitext(os.path.basename(image_path)) + + img = Image.open(image_path) + width, height = img.size + + block_width = width + block_height = height // num_splits + + for i in range(num_splits): + left = 0 + upper = i * block_height + right = block_width + lower = (i + 1) * block_height + cropped_img = img.crop((left, upper, right, lower)) + cropped_img.save(f"{output_folder}/{file_name}_{i}{file_ext}") + + +def find_images(input_folder): + image_files = [] + for root, _, files in os.walk(input_folder): + for f in files: + if f.lower().endswith(('.png', '.jpg', '.jpeg')): + image_files.append(os.path.join(root, f)) + return image_files + + +def process_images(image_files, input_folder, output_folder, num_splits, num_processes): + with ProcessPoolExecutor(max_workers=num_processes) as executor: + for image_file in image_files: + executor.submit(split_image, image_file, input_folder, output_folder, num_splits) + + +def main(): + parser = argparse.ArgumentParser(description='Split images into smaller tiles') + parser.add_argument('--input', type=str, required=True, help='Path to the input folder containing images') + parser.add_argument('--output', type=str, required=True, help='Path to the output folder for saving the tiles') + parser.add_argument('--num_splits', type=int, default=16, help='Size of the tiles (default: 4096)') + parser.add_argument('--processes', type=int, default=1, help='Number of processes (default: number of CPU cores)') + args = parser.parse_args() + + input_folder = args.input + output_folder = args.output + num_splits = args.num_splits + num_processes = args.processes + + if not os.path.exists(output_folder): + 
os.makedirs(output_folder) + + image_files = find_images(input_folder) + process_images(image_files, input_folder, output_folder, num_splits, num_processes) + + +if __name__ == "__main__": + main() + diff --git a/examples/longvit/datasets.py b/examples/longvit/datasets.py new file mode 100644 index 00000000..02880fc4 --- /dev/null +++ b/examples/longvit/datasets.py @@ -0,0 +1,466 @@ +# -------------------------------------------------------- +# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442) +# Github source: https://github.com/microsoft/unilm/tree/master/beit3 +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# --------------------------------------------------------' + +import os +import json +import random +import torch +import glob +from collections import defaultdict, Counter +from torchvision import transforms +from torchvision.datasets.folder import default_loader +from torchvision.transforms import InterpolationMode +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.data.transforms import RandomResizedCropAndInterpolation +from timm.data import create_transform + +import utils +import openslide +import pandas as pd +import numpy as np + +import PIL +PIL.Image.MAX_IMAGE_PIXELS = None + + +class BaseDataset(torch.utils.data.Dataset): + def __init__( + self, data_path, split, transform, + task=None, k_fold=0, + ): + index_files = self.get_index_files(split, k_fold=k_fold, task=task) + self.data_path = data_path + items = [] + self.index_files = index_files + + offset = 0 + for _index_file in index_files: + index_file = os.path.join(data_path, _index_file) + with open(index_file, mode="r", encoding="utf-8") as reader: + for line in reader: + data = json.loads(line) + items.append(data) + print("Load %d image-text pairs from %s. " % (len(items) - offset, index_file)) + offset = len(items) + self.items = items + self.loader = default_loader + self.transform = transform + self.split = split + + @staticmethod + def get_index_files(split): + raise NotImplementedError() + + def _get_image(self, image_path: str): + image_path = os.path.join(self.data_path, image_path) + image = self.loader(image_path) + return self.transform(image) + + def __getitem__(self, index: int): + raise NotImplementedError() + + def __len__(self) -> int: + return len(self.items) + + def __repr__(self) -> str: + head = "Dataset " + self.__class__.__name__ + body = '{' + "\n Number of items: %s," % self.__len__() + body += "\n data root = %s," % self.data_path + body += "\n split = %s," % self.split + body += "\n dataset index files = %s" % str(self.index_files) + body += "\n transforms = [" + for t in self.transform.transforms: + body += "\n %s" % str(t) + body += "\n ]" + body += "\n}" + + return head + body + + +def _write_data_into_jsonl(items, jsonl_file): + with open(jsonl_file, mode="w", encoding="utf-8") as writer: + for data in items: + writer.write(json.dumps(data, indent=None)) + writer.write('\n') + print("Write %s with %d items !" 
% (jsonl_file, len(items))) + + +def df_prep(data, label_dict, ignore, label_col): + if label_col != 'label': + data['label'] = data[label_col].copy() + + mask = data['label'].isin(ignore) + data = data[~mask] + data.reset_index(drop=True, inplace=True) + for i in data.index: + key = data.loc[i, 'label'] + data.at[i, 'label'] = label_dict[key] + + return data + + +def get_split_from_df(slide_data, all_splits, prop=1.0, seed=1, split_key='train'): + split = all_splits[split_key].str.rstrip('.svs') + split = split.dropna().reset_index(drop=True) + + if len(split) > 0: + mask = slide_data['slide_id'].isin(split.tolist()) + df_slice = slide_data[mask].reset_index(drop=True) + if split_key == 'train' and prop != 1.0: + df_slice = df_slice.sample(frac=prop, random_state=seed).reset_index(drop=True) + if split_key == 'train': + print(df_slice.head()) + print("Traing Data Size ({%0.2f}): %d" % (prop, df_slice.shape[0])) + else: + df_slice = None + + return df_slice + + +class TCGASubtypingDataset(BaseDataset): + def __init__(self, data_path, split, transform, task, k_fold, image_dir, seq_parallel=False, cached_randaug=False): + super().__init__( + data_path=data_path, split=split, + transform=transform, task=task, k_fold=k_fold, + ) + self.k_fold = k_fold + self.image_dir = image_dir + self.seq_parallel = seq_parallel + self.cached_randaug = cached_randaug + + @staticmethod + def get_index_files(split, k_fold=0, task=None): + if split == "train": + return ("{}.train.index.{}.jsonl".format(task.replace("_subtyping", ""), k_fold), ) + elif split == "val": + return ("{}.val.index.{}.jsonl".format(task.replace("_subtyping", ""), k_fold), ) + elif split == "test": + return ("{}.test.index.{}.jsonl".format(task.replace("_subtyping", ""), k_fold), ) + else: + raise RuntimeError("split %s is not found!" 
% split) + + def __getitem__(self, index: int): + data = dict() + item = self.items[index] + if self.seq_parallel: + img_path = item["image_path"] + "_{}.jpg".format(utils.get_rank()) + else: + img_path = item["image_path"] + ".jpg" + img = self._get_image(img_path) + data["image"] = img + data["label"] = item["label"] + return data + + def _get_image(self, image_path: str): + if self.cached_randaug: + if self.split == "train": + cur_epoch = int(os.environ.get('cur_epoch')) + image_path = os.path.join(self.image_dir, "epoch_{}".format(cur_epoch), image_path) + else: + image_path = os.path.join(self.image_dir, "wo_augmentation", image_path) + else: + image_path = os.path.join(self.image_dir, image_path) + + image = self.loader(image_path) + return self.transform(image) + + @staticmethod + def _make_tcga_index(task, csv_path, csv_split_path, k_fold, index_path, ignore, label_dict, split): + items = [] + index_file = os.path.join(index_path, f"{task}.{split}.index.{k_fold}.jsonl") + + slide_data = pd.read_csv(csv_path) + slide_data = df_prep(slide_data, label_dict, ignore, label_col="oncotree_code") + slide_data['slide_id'] = slide_data['slide_id'].apply(lambda x: x.replace(".svs", "")) + + all_splits = pd.read_csv(csv_split_path) + slide_data_split = get_split_from_df(slide_data, all_splits, split_key=split) + + for index, row in slide_data_split.iterrows(): + items.append({ + "image_path": row["slide_id"], + "label": row["label"], + }) + file_path = os.path.join(index_path.replace("tcga_", "") + "_svs", "{}.svs".format(row["slide_id"])) + if not os.path.exists(file_path): + print("file {} do not exists".format(row["slide_id"])) + + _write_data_into_jsonl(items, index_file) + + @classmethod + def make_dataset_index(cls, task, csv_path, csv_split_path, k_fold, index_path, ignore, label_dict): + cls._make_tcga_index( + task=task, csv_path=csv_path, csv_split_path=csv_split_path, k_fold=k_fold, index_path=index_path, + ignore=ignore, label_dict=label_dict, split="train", + ) + cls._make_tcga_index( + task=task, csv_path=csv_path, csv_split_path=csv_split_path, k_fold=k_fold, index_path=index_path, + ignore=ignore, label_dict=label_dict, split="val", + ) + cls._make_tcga_index( + task=task, csv_path=csv_path, csv_split_path=csv_split_path, k_fold=k_fold, index_path=index_path, + ignore=ignore, label_dict=label_dict, split="test", + ) + + +def get_survival_split_from_df(slide_data, all_splits, split_key='train'): + split = all_splits[split_key] + split = split.dropna().reset_index(drop=True) + + if len(split) > 0: + mask = slide_data['slide_id'].isin(split.tolist()) + df_slice = slide_data[mask].reset_index(drop=True) + else: + df_slice = None + + return df_slice + + +class TCGASurvivalDataset(BaseDataset): + def __init__(self, data_path, split, transform, task, k_fold, image_dir, seq_parallel=False, cached_randaug=False): + super().__init__( + data_path=data_path, split=split, + transform=transform, task=task, k_fold=k_fold, + ) + self.k_fold = k_fold + self.image_dir = image_dir + self.seq_parallel = seq_parallel + self.cached_randaug = cached_randaug + + @staticmethod + def get_index_files(split, k_fold=0, task=None): + if split == "train": + return ("{}.train.index.{}.jsonl".format(task.replace("_survival", ""), k_fold), ) + elif split == "val": + return ("{}.val.index.{}.jsonl".format(task.replace("_survival", ""), k_fold), ) + elif split == "test": + return ("{}.val.index.{}.jsonl".format(task.replace("_survival", ""), k_fold), ) + else: + raise RuntimeError("split %s is not found!" 
% split) + + def __getitem__(self, index: int): + data = dict() + item = self.items[index] + if self.seq_parallel: + img_path = item["image_paths"][0].replace(".svs", "") + "_{}.jpg".format(utils.get_rank()) + else: + img_path = item["image_paths"][0].replace(".svs", "") + ".jpg" + img = self._get_image(img_path) + case_id = item["case_id"] + data["image"] = img + data["label"] = item["label"] + data["event_time"] = item["event_time"] + data["censorship"] = item["censorship"] + return data + + def _get_image(self, image_path: str): + if self.cached_randaug: + if self.split == "train": + cur_epoch = int(os.environ.get('cur_epoch')) + image_path = os.path.join(self.image_dir, "epoch_{}".format(cur_epoch), image_path) + else: + image_path = os.path.join(self.image_dir, "wo_augmentation", image_path) + else: + image_path = os.path.join(self.image_dir, image_path) + + image = self.loader(image_path) + return self.transform(image) + + @staticmethod + def _make_tcga_index(task, csv_path, csv_split_path, k_fold, index_path, split): + items = [] + os.makedirs(index_path, exist_ok=True) + index_file = os.path.join(index_path, f"{task}.{split}.index.{k_fold}.jsonl") + + slide_data = pd.read_csv(csv_path, low_memory=False) + if 'case_id' not in slide_data: + slide_data.index = slide_data.index.str[:12] + slide_data['case_id'] = slide_data.index + slide_data = slide_data.reset_index(drop=True) + + label_col = "survival_months" + assert label_col in slide_data.columns + + # if "IDC" in slide_data['oncotree_code'].values: # must be BRCA (and if so, use only IDCs) + # slide_data = slide_data[slide_data['oncotree_code'] == 'IDC'] + + patients_df = slide_data.drop_duplicates(['case_id']).copy() + uncensored_df = patients_df[patients_df['censorship'] < 1] + + n_bins = 4 + eps = 1e-6 + disc_labels, q_bins = pd.qcut(uncensored_df[label_col], q=n_bins, retbins=True, labels=False) + q_bins[-1] = slide_data[label_col].max() + eps + q_bins[0] = slide_data[label_col].min() - eps + + disc_labels, q_bins = pd.cut(patients_df[label_col], bins=q_bins, retbins=True, labels=False, right=False, include_lowest=True) + patients_df.insert(2, 'label', disc_labels.values.astype(int)) + + patient_dict = {} + slide_data = slide_data.set_index('case_id') + for patient in patients_df['case_id']: + slide_ids = slide_data.loc[patient, 'slide_id'] + if isinstance(slide_ids, str): + slide_ids = np.array(slide_ids).reshape(-1).tolist() + else: + slide_ids = slide_ids.values.tolist() + patient_dict.update({patient:slide_ids}) + + slide_data = patients_df + slide_data.reset_index(drop=True, inplace=True) + slide_data = slide_data.assign(slide_id=slide_data['case_id']) + + label_dict = {} + key_count = 0 + for i in range(len(q_bins)-1): + for c in [0, 1]: + print('{} : {}'.format((i, c), key_count)) + label_dict.update({(i, c):key_count}) + key_count+=1 + + for i in slide_data.index: + key = slide_data.loc[i, 'label'] + slide_data.at[i, 'disc_label'] = key + censorship = slide_data.loc[i, 'censorship'] + key = (key, int(censorship)) + slide_data.at[i, 'label'] = label_dict[key] + + patients_df = slide_data.drop_duplicates(['case_id']) + patient_data = {'case_id':patients_df['case_id'].values, 'label':patients_df['label'].values} + + new_cols = list(slide_data.columns[-2:]) + list(slide_data.columns[:-2]) + slide_data = slide_data[new_cols] + + all_splits = pd.read_csv(csv_split_path) + slide_data_split = get_survival_split_from_df(slide_data, all_splits, split_key=split) + + for index, row in slide_data_split.iterrows(): + case_id = 
row["case_id"] + items.append({ + "case_id" : row["case_id"], + "label": row["disc_label"], + "event_time": row["survival_months"], + "censorship": row["censorship"], + "image_paths": patient_dict[case_id], + }) + for slide_id in patient_dict[case_id]: + file_path = os.path.join(f"/tmp/tcga/{task}_svs".replace("tcga_", ""), slide_id) + if not os.path.exists(file_path): + print("file {} do not exists".format(row["slide_id"])) + + _write_data_into_jsonl(items, index_file) + + @classmethod + def make_dataset_index(cls, task, csv_path, csv_split_path, k_fold, index_path): + cls._make_tcga_index( + task=task, csv_path=csv_path, csv_split_path=csv_split_path, k_fold=k_fold, index_path=index_path, + split="train", + ) + cls._make_tcga_index( + task=task, csv_path=csv_path, csv_split_path=csv_split_path, k_fold=k_fold, index_path=index_path, + split="val", + ) + + +task2dataset = { + "tcga_lung_subtyping": TCGASubtypingDataset, + "tcga_kidney_subtyping": TCGASubtypingDataset, + "tcga_brca_subtyping": TCGASubtypingDataset, + "tcga_ucec_survival": TCGASurvivalDataset, + "tcga_luad_survival": TCGASurvivalDataset, + "tcga_brca_survival": TCGASurvivalDataset, +} + + +def create_dataloader(dataset, is_train, batch_size, num_workers, pin_mem, seq_parallel=False, seed=None): + if is_train: + if seq_parallel: + generator = torch.Generator() + generator.manual_seed(seed) + sampler = torch.utils.data.RandomSampler(dataset, generator=generator) + else: + num_tasks = utils.get_world_size() + global_rank = utils.get_rank() + + sampler = torch.utils.data.DistributedSampler( + dataset, num_replicas=num_tasks, rank=global_rank, shuffle=is_train + ) + else: + sampler = torch.utils.data.SequentialSampler(dataset) + + return torch.utils.data.DataLoader( + dataset, sampler=sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_mem, + drop_last=is_train, + collate_fn=utils.merge_batch_tensors_by_dict_key, + ) + + +def build_transform(is_train, args): + + if is_train: + t = [] + if args.randaug: + t += [ + RandomResizedCropAndInterpolation(args.input_size, scale=(0.5, 1.0), interpolation=args.train_interpolation), + transforms.RandomHorizontalFlip(), + ] + + t += [ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ] + t = transforms.Compose(t) + else: + t = transforms.Compose([ + # transforms.Resize((args.input_size, args.input_size), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) + ]) + + return t + + +def create_dataset_by_split(args, split, is_train=True): + transform = build_transform(is_train=is_train, args=args) + print(transform) + dataset_class = task2dataset[args.task] + + opt_kwargs = {} + if args.task.startswith("tcga"): + opt_kwargs["k_fold"] = args.k_fold + opt_kwargs["image_dir"] = args.image_dir + opt_kwargs["seq_parallel"] = args.seq_parallel + opt_kwargs["cached_randaug"] = args.cached_randaug + + dataset = dataset_class( + data_path=args.data_path, split=split, + transform=transform, task=args.task, **opt_kwargs, + ) + if is_train: + batch_size = args.batch_size + elif hasattr(args, "eval_batch_size") and args.eval_batch_size is not None: + batch_size = args.eval_batch_size + else: + batch_size = int(args.batch_size * 1.0) + + return create_dataloader( + dataset, is_train=is_train, batch_size=batch_size, + num_workers=args.num_workers, pin_mem=args.pin_mem, + seq_parallel=args.seq_parallel, seed=args.seed, + ) + + +def 
create_downstream_dataset(args, is_eval=False): + if is_eval: + return create_dataset_by_split(args, split="test", is_train=False) + else: + return \ + create_dataset_by_split(args, split="train", is_train=True), \ + create_dataset_by_split(args, split="val", is_train=False), diff --git a/examples/longvit/engine_for_finetuning.py b/examples/longvit/engine_for_finetuning.py new file mode 100644 index 00000000..78708162 --- /dev/null +++ b/examples/longvit/engine_for_finetuning.py @@ -0,0 +1,321 @@ +# -------------------------------------------------------- +# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442) +# Github source: https://github.com/microsoft/unilm/tree/master/beit3 +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# --------------------------------------------------------' + +import math +import sys +import json +import numpy as np +from typing import Iterable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.utils import ModelEma +from timm.utils import accuracy, ModelEma +from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from sklearn.preprocessing import label_binarize +from sklearn.metrics import roc_auc_score, roc_curve +from sklearn.metrics import auc as calc_auc +from sksurv.metrics import concordance_index_censored + +import utils + + +class TaskHandler(object): + def __init__(self) -> None: + self.metric_logger = None + self.split = None + + def train_batch(self, model, **kwargs): + raise NotImplementedError() + + def eval_batch(self, model, **kwargs): + raise NotImplementedError() + + def before_eval(self, metric_logger, data_loader, **kwargs): + self.metric_logger = metric_logger + self.split = data_loader.dataset.split + + def after_eval(self, **kwargs): + raise NotImplementedError() + + +class TCGASubtypingHandler(TaskHandler): + def __init__(self, args) -> None: + super().__init__() + if args.label_smoothing > 0.: + self.criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing) + else: + self.criterion = torch.nn.CrossEntropyLoss() + self.probs = [] + self.labels = [] + + def train_batch(self, model, image, label): + logits = model(image=image) + return { + "loss": self.criterion(logits, label), + } + + def before_eval(self, metric_logger, data_loader, **kwargs): + self.probs.clear() + self.labels.clear() + self.metric_logger = metric_logger + + def eval_batch(self, model, image, label): + logits = model(image=image) + probs = F.softmax(logits, dim=1) + batch_size = image.shape[0] + acc = (logits.max(-1)[-1] == label).float().sum(0) * 100.0 / batch_size + self.metric_logger.meters['acc'].update(acc.item(), n=batch_size) + self.probs.append(probs) + self.labels.append(label) + + def after_eval(self, data_items, **kwargs): + print('* Acc {acc.global_avg:.3f}'.format(acc=self.metric_logger.acc)) + result_dict = {k: meter.global_avg for k, meter in self.metric_logger.meters.items()} + all_probs = torch.cat(self.probs, dim=0) + all_labels = torch.cat(self.labels, dim=0).tolist() + n_classes = all_probs.size(-1) + if n_classes == 2: + auc_score = roc_auc_score(all_labels, all_probs[:, 1].tolist()) + else: + aucs = [] + binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)]) + for class_idx in range(n_classes): + if class_idx in all_labels: + fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx].tolist()) + print(calc_auc(fpr, tpr)) + 
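                    # per-class one-vs-rest AUC; np.nanmean below macro-averages over classes, skipping absent ones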
aucs.append(calc_auc(fpr, tpr)) + else: + print('nan') + aucs.append(float('nan')) + + auc_score = np.nanmean(np.array(aucs)) + + patient_results = {} + for index in range(len(all_labels)): + slide_id = data_items[index]["image_path"] + assert all_labels[index] == data_items[index]["label"] + patient_results.update({slide_id: {'prob': all_probs[index, :].tolist(), 'label': all_labels[index]}}) + + result_dict["auc"] = auc_score + print("Acc: {} Auc: {}".format(result_dict["acc"], result_dict["auc"])) + return result_dict, "auc" + + +class TCGASurvivalHandler(TaskHandler): + def __init__(self, args) -> None: + super().__init__() + self.criterion = utils.NLLSurvLoss(alpha=0.0) + self.risk_scores = [] + self.censorships = [] + self.event_times = [] + self.labels = [] + + def train_batch(self, model, image, label, event_time, censorship): + logits = model(image=image) + hazards = torch.sigmoid(logits) + survival = torch.cumprod(1 - hazards, dim=1) + return { + "loss": self.criterion(hazards=hazards, S=survival, Y=label, c=censorship), + } + + def before_eval(self, metric_logger, data_loader, **kwargs): + self.risk_scores.clear() + self.censorships.clear() + self.event_times.clear() + self.labels.clear() + self.metric_logger = metric_logger + + def eval_batch(self, model, image, label, event_time, censorship): + logits = model(image=image) + probs = F.softmax(logits, dim=1) + hazards = torch.sigmoid(logits) + survival = torch.cumprod(1 - hazards, dim=1) + risk = -torch.sum(survival, dim=1) + + batch_size = image.shape[0] + acc = (logits.max(-1)[-1] == label).float().sum(0) * 100.0 / batch_size + self.metric_logger.meters['acc'].update(acc.item(), n=batch_size) + + self.risk_scores.append(risk) + self.censorships.append(censorship) + self.event_times.append(event_time) + self.labels.append(label) + + def after_eval(self, data_items, **kwargs): + print('* Acc {acc.global_avg:.3f}'.format(acc=self.metric_logger.acc)) + result_dict = {k: meter.global_avg for k, meter in self.metric_logger.meters.items()} + all_risk_scores = torch.cat(self.risk_scores, dim=0).cpu().numpy() + all_censorships = torch.cat(self.censorships, dim=0).cpu().numpy() + all_event_times = torch.cat(self.event_times, dim=0).cpu().numpy() + all_labels = torch.cat(self.labels, dim=0).cpu().numpy() + + patient_results = {} + for index in range(len(all_risk_scores)): + case_id = data_items[index]["case_id"] + assert int(all_event_times[index]) == int(data_items[index]["event_time"]) + patient_results.update({case_id: {'case_id': case_id, 'risk': all_risk_scores[index], 'disc_label': all_labels[index], 'survival':all_event_times[index], 'censorship':all_censorships[index]}}) + + c_index = concordance_index_censored((1-all_censorships).astype(bool), all_event_times, all_risk_scores, tied_tol=1e-08)[0] + result_dict["c_index"] = c_index + print("Acc: {} C_Index: {}".format(result_dict["acc"], result_dict["c_index"])) + return result_dict, "c_index" + + +def get_handler(args): + if args.task.endswith("subtyping"): + return TCGASubtypingHandler(args) + elif args.task.endswith("survival"): + return TCGASurvivalHandler(args) + else: + raise NotImplementedError("Sorry, %s is not support." 
% args.task) + + +def train_one_epoch( + model: torch.nn.Module, data_loader: Iterable, + optimizer: torch.optim.Optimizer, device: torch.device, + handler: TaskHandler, epoch: int, start_steps: int, + lr_schedule_values: list, loss_scaler, max_norm: float = 0, + update_freq: int = 1, model_ema: Optional[ModelEma] = None, + log_writer: Optional[utils.TensorboardLogger] = None, + task = None, seq_parallel = False, +): + model.train(True) + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) + metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + print_freq = 10 + + if loss_scaler is None: + model.zero_grad() + model.micro_steps = 0 + else: + optimizer.zero_grad() + + for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)): + step = data_iter_step // update_freq + global_step = start_steps + step # global training iteration + # Update LR & WD for the first acc + if lr_schedule_values is not None and data_iter_step % update_freq == 0: + for i, param_group in enumerate(optimizer.param_groups): + if lr_schedule_values is not None: + param_group["lr"] = lr_schedule_values[global_step] * param_group["lr_scale"] + # put input data into cuda + for tensor_key in data.keys(): + data[tensor_key] = data[tensor_key].to(device, non_blocking=True) + # print("input %s = %s" % (tensor_key, data[tensor_key])) + if loss_scaler is None and tensor_key.startswith("image"): + data[tensor_key] = data[tensor_key].half() + + if loss_scaler is None: + results = handler.train_batch(model, **data) + else: + with torch.cuda.amp.autocast(): + results = handler.train_batch(model, **data) + + loss = results.pop("loss") + loss_value = loss.item() + + if seq_parallel: + if utils.get_rank() != 0: + loss = loss * 0.0 + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + if loss_scaler is None: + loss /= update_freq + model.backward(loss) + model.step() + + if (data_iter_step + 1) % update_freq == 0: + # model.zero_grad() + # Deepspeed will call step() & model.zero_grad() automatic + if model_ema is not None: + model_ema.update(model) + grad_norm = None + loss_scale_value = utils.get_loss_scale_for_deepspeed(model) + else: + # this attribute is added by timm on one optimizer (adahessian) + is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order + loss /= update_freq + grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, + parameters=model.parameters(), create_graph=is_second_order, + update_grad=(data_iter_step + 1) % update_freq == 0) + if (data_iter_step + 1) % update_freq == 0: + optimizer.zero_grad() + if model_ema is not None: + model_ema.update(model) + loss_scale_value = loss_scaler.state_dict()["scale"] + + torch.cuda.synchronize() + + metric_logger.update(loss=loss_value) + metric_logger.update(loss_scale=loss_scale_value) + min_lr = 10. + max_lr = 0. 
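        # track the min/max learning rate across param groups for logging (groups differ when lr_scale / layer-wise lr decay != 1)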
+ for group in optimizer.param_groups: + min_lr = min(min_lr, group["lr"]) + max_lr = max(max_lr, group["lr"]) + + metric_logger.update(lr=max_lr) + metric_logger.update(min_lr=min_lr) + weight_decay_value = None + for group in optimizer.param_groups: + if group["weight_decay"] > 0: + weight_decay_value = group["weight_decay"] + metric_logger.update(weight_decay=weight_decay_value) + metric_logger.update(grad_norm=grad_norm) + + if log_writer is not None: + kwargs = { + "loss": loss_value, + } + for key in results: + kwargs[key] = results[key] + log_writer.update(head="train", **kwargs) + + kwargs = { + "loss_scale": loss_scale_value, + "lr": max_lr, + "min_lr": min_lr, + "weight_decay": weight_decay_value, + "grad_norm": grad_norm, + } + log_writer.update(head="opt", **kwargs) + log_writer.set_step() + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def evaluate(data_loader, model, device, handler): + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Test:' + + # switch to evaluation mode + model.eval() + handler.before_eval(metric_logger=metric_logger, data_loader=data_loader) + + for data in metric_logger.log_every(data_loader, 10, header): + for tensor_key in data.keys(): + data[tensor_key] = data[tensor_key].to(device, non_blocking=True) + + with torch.cuda.amp.autocast(): + handler.eval_batch(model=model, **data) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + + return handler.after_eval(data_loader.dataset.items) diff --git a/examples/longvit/get_started/get_started_for_tcga_pretraining.md b/examples/longvit/get_started/get_started_for_tcga_pretraining.md new file mode 100644 index 00000000..b7c40a15 --- /dev/null +++ b/examples/longvit/get_started/get_started_for_tcga_pretraining.md @@ -0,0 +1,15 @@ +# Pretraining LongViT on TCGA using DINO + +## Setup + +1. Download TCGA diagnostic whole slides from [NIH Genomic Data Commons Data Portal](https://portal.gdc.cancer.gov/). + +2. Generate 1,024x1,024 regions from WSIs: +``` +# we randomly generate 100 small regions for each whole slide image +python data_preprocessing/generate_1024_crops.py /path/to/your_WSIs /path/to/your_crops 100 +``` + +## Pretraining LongViT + +Replace the `vision_transformer.py` in [DINO](https://github.com/facebookresearch/dino) with [LongViT vision_transformer.py](../pretraining/vision_transformer.py), and modify the `global crop size` to 1024 and `local crop size` to 512 to preform LongViT pretraining using DINO framework. \ No newline at end of file diff --git a/examples/longvit/get_started/get_started_for_tcga_subtyping.md b/examples/longvit/get_started/get_started_for_tcga_subtyping.md new file mode 100644 index 00000000..9cbac865 --- /dev/null +++ b/examples/longvit/get_started/get_started_for_tcga_subtyping.md @@ -0,0 +1,169 @@ +# Fine-tuning LongViT on TCGA Subtyping + +## Setup + +1. Download TCGA diagnostic whole slides from [NIH Genomic Data Commons Data Portal](https://portal.gdc.cancer.gov/), and organize the dataset (e.g., BRCA WSIs) as following structure: + +``` +/path/to/your_WSIs/ + TCGA-3C-AALI-01Z-00-DX1.F6E9A5DF-D8FB-45CF-B4BD-C6B76294C291.svs + ... + TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9.svs + ... +``` + +2. 
Download [dataset annotation csv](https://github.com/mahmoodlab/HIPT/tree/master/2-Weakly-Supervised-Subtyping/dataset_csv) and [splits for cross validation](https://github.com/mahmoodlab/HIPT/tree/master/2-Weakly-Supervised-Subtyping/splits/10foldcv_subtype) from the HIPT repository. + +3. Generate the index json files of each split using the following command. +``` +# Modify the `csv_path` and `csv_split_path` to your path. +python data_preprocessing/create_tcga_subtyping_index.py +``` + +4. Resize whole slide images to the desired size for finetuning. +``` +python data_preprocessing/convert_wsi_to_images.py /path/to/your_WSIs /path/to/your_resized_WSIs ${target_size} ${wsi_level} +``` + +5. (Optional) For very large images (e.g., 32,768x32,768), we suggest parallelizing the training across multiple GPU devices due to the constraints of computation and memory. We split the sequence of millions of patches along the sequence dimension. +``` +# num_splits is equal to the number of GPUs you used (e.g., 8 in our experiment) +python data_preprocessing/split_to_small_images.py /path/to/your_resized_WSIs /path/to/your_splited_WSIs --num_splits ${num_splits} +``` + +6. (Optional) We find performing image augmentation slightly improves the performance. For very large images (e.g., 32,768x32,768), we perform the augmentation and cache the resulted images of each epoch. +``` +# Run the command 10 times (number of epochs in finetuning) using i from 0-9 +python data_preprocessing/cache_transformed_images.py /path/to/your_resized_WSIs /path/to/your_augmentated_WSIs/epoch_$i --input_size 32768 +``` + +Split these cached images as in step 5 and organize the data as following structure: +``` +/path/to/your_splited_WSIs/ + epoch_0/ + TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9_0.jpg + ... + TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9_7.jpg + ... + epoch_1/ + TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9_0.jpg + ... + TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9_7.jpg + ... + ... + epoch_5/ + ... + epoch_9/ + wo_augmentation/ +``` + + +## Example: Fine-tuning LongViT on TCGA Subtyping + +The LongViT model can be fine-tuned using 8 V100-32GB. For images with a size less than or equal to 16,384x16,384, we can directly perform finetuning without using sequence parallel. + +```bash +# IMAGE_SIZE - {1024, 4096, 8192, 16384} +# TASK - {"brca", "kidney", "lung"} +# K_FOLD - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +python -m torch.distributed.launch --nproc_per_node=8 run_longvit_finetuning.py \ + --input_size ${IMAGE_SIZE} \ + --model longvit_small_patch32_${IMAGE_SIZE} \ + --task tcga_${TASK}_subtyping \ + --batch_size 1 \ + --layer_decay 1.0 \ + --lr 5e-5 \ + --update_freq 1 \ + --epochs 10 \ + --warmup_epochs 1 \ + --drop_path 0.1 \ + --finetune /your_longvit_model_path/longvit_small_patch32_1024.pth + --data_path ./subtyping_split_index/tcga_${TASK} \ + --image_dir /path/to/your_resized_WSIs \ + --output_dir /path/to/save/your_model \ + --log_dir /path/to/save/your_model/log \ + --weight_decay 0.05 \ + --seed 42 \ + --save_ckpt_freq 5 \ + --k_fold ${K_FOLD} \ + --num_workers 1 \ + --enable_deepspeed \ + --model_key teacher \ + --randaug +``` +- `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretraining). +- `--randaug`: perform image augmentation. 
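For the sequence-parallel run below, each GPU rank consumes one of the horizontal strips produced in step 5: `split_to_small_images.py` saves `num_splits` strips named `{slide_id}_{i}.jpg`, and `TCGASubtypingDataset` loads the strip whose suffix matches the rank when `--seq_parallel` is set. A rough sketch of the per-rank sequence length, assuming 8 GPUs and the 32x32 patch size (illustration only, not a repo API):

```python
# 8-way sequence parallelism on a 32,768 x 32,768 slide with 32x32 patches.
num_gpus, image_size, patch_size = 8, 32_768, 32
strip_height = image_size // num_gpus                        # 4,096 pixels per rank (height // num_splits)
tokens_per_rank = (strip_height // patch_size) * (image_size // patch_size)
print(tokens_per_rank)                                       # 131,072 of the ~1M patch tokens per GPU
# rank r reads "<slide_id>_<r>.jpg", e.g. "..._0.jpg" on rank 0 (see datasets.py)
```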
+ + +Sequence parallel of training on 32,768x32,768 images: + +```bash +# TASK - {"brca", "kidney", "lung"} +# K_FOLD - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +python -m torch.distributed.launch --nproc_per_node=8 run_longvit_finetuning.py \ + --input_size 32768 \ + --model longvit_small_patch32_32768 \ + --task tcga_${TASK}_subtyping \ + --batch_size 2 \ + --layer_decay 1.0 \ + --lr 5e-5 \ + --update_freq 4 \ + --epochs 10 \ + --warmup_epochs 1 \ + --drop_path 0.1 \ + --finetune /your_longvit_model_path/longvit_small_patch32_1024.pth + --data_path ./subtyping_split_index/tcga_${TASK} \ + --image_dir /path/to/your_splited_WSIs \ + --output_dir /path/to/save/your_model \ + --log_dir /path/to/save/your_model/log \ + --weight_decay 0.05 \ + --seed 42 \ + --save_ckpt_freq 5 \ + --k_fold ${K_FOLD} \ + --num_workers 1 \ + --enable_deepspeed \ + --model_key teacher \ + --seq_parallel \ + --cached_randaug +``` +- `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretraining). +- `--seq_parallel`: parallelize the training for very large images. +- `--cached_randaug`: perform training on the cached augmented images. + + +## Example: Evaluate LongViT on TCGA Subtyping + +```bash +# IMAGE_SIZE - {1024, 4096, 8192, 16384, 32768} +# TASK - {"brca", "kidney", "lung"} +# K_FOLD - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +python -m torch.distributed.launch --nproc_per_node=1 run_longvit_finetuning.py \ + --input_size ${IMAGE_SIZE} \ + --model longvit_small_patch32_${IMAGE_SIZE} \ + --task tcga_${TASK}_subtyping \ + --batch_size 1 \ + --layer_decay 1.0 \ + --lr 5e-5 \ + --update_freq 1 \ + --epochs 10 \ + --warmup_epochs 1 \ + --drop_path 0.1 \ + --finetune /path/to/save/your_model/checkpoint-best/mp_rank_00_model_states.pt \ + --data_path ./subtyping_split_index/tcga_${TASK} \ + --image_dir /path/to/your_resized_WSIs \ + --output_dir /path/to/save/your_model \ + --log_dir /path/to/save/your_model/log \ + --weight_decay 0.05 \ + --seed 42 \ + --save_ckpt_freq 5 \ + --k_fold ${K_FOLD} \ + --num_workers 1 \ + --enable_deepspeed \ + --model_key module \ + --eval \ + --no_auto_resume +``` +- `--eval`: performing evaluation on test set. +- `--finetune`: best val model used for test. + +For the model trained with sequence parallel, add `--seq_parallel` and use the same number of GPUs as training to perform evaluation. \ No newline at end of file diff --git a/examples/longvit/get_started/get_started_for_tcga_survival_prediction.md b/examples/longvit/get_started/get_started_for_tcga_survival_prediction.md new file mode 100644 index 00000000..8f9bf686 --- /dev/null +++ b/examples/longvit/get_started/get_started_for_tcga_survival_prediction.md @@ -0,0 +1,141 @@ +# Fine-tuning LongViT on TCGA Survival Prediction + +## Setup + +1. Download TCGA diagnostic whole slides from [NIH Genomic Data Commons Data Portal](https://portal.gdc.cancer.gov/), and organize the dataset (e.g., BRCA WSIs) as following structure: + +``` +/path/to/your_WSIs/ + TCGA-3C-AALI-01Z-00-DX1.F6E9A5DF-D8FB-45CF-B4BD-C6B76294C291.svs + ... + TCGA-4H-AAAK-01Z-00-DX1.ABF1B042-1970-4E28-8671-43AAD393D2F9.svs + ... +``` + +2. Download [dataset annotation csv](https://github.com/mahmoodlab/MCAT/tree/master/datasets_csv_sig) and [splits for cross validation](https://github.com/mahmoodlab/MCAT/tree/master/splits/5foldcv) from the MCAT repository. + +3. Generate the index json files of each split using the following command. +``` +# Modify the `csv_path` and `csv_split_path` to your path. 
+python data_preprocessing/create_tcga_survival_index.py +``` + +4. Resize whole slide images to the desired size for finetuning. +``` +python data_preprocessing/convert_wsi_to_images.py /path/to/your_WSIs /path/to/your_resized_WSIs ${target_size} ${wsi_level} +``` + +5. (Optional) For very large images (e.g., 32,768x32,768), we suggest parallelizing the training across multiple GPU devices due to the constraints of computation and memory. We split the sequence of millions of patches along the sequence dimension. +``` +# num_splits is equal to the number of GPUs you used (e.g., 8 in our experiment) +python data_preprocessing/split_to_small_images.py /path/to/your_resized_WSIs /path/to/your_splited_WSIs --num_splits ${num_splits} +``` + + +## Example: Fine-tuning LongViT on TCGA Survival Prediction + +The LongViT model can be fine-tuned using 8 V100-32GB. For images with a size less than or equal to 16,384x16,384, we can directly perform finetuning without using sequence parallel. + +```bash +# IMAGE_SIZE - {1024, 4096, 8192, 16384} +# TASK - {"brca", "kidney", "lung"} +# K_FOLD - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +python -m torch.distributed.launch --nproc_per_node=8 run_longvit_finetuning.py \ + --input_size ${IMAGE_SIZE} \ + --model longvit_small_patch32_${IMAGE_SIZE} \ + --task tcga_${TASK}_survival \ + --batch_size 1 \ + --layer_decay 1.0 \ + --lr 5e-5 \ + --update_freq 1 \ + --epochs 10 \ + --warmup_epochs 1 \ + --drop_path 0.1 \ + --finetune /your_longvit_model_path/longvit_small_patch32_1024.pth + --data_path ./survival_split_index/tcga_${TASK} \ + --image_dir /path/to/your_resized_WSIs \ + --output_dir /path/to/save/your_model \ + --log_dir /path/to/save/your_model/log \ + --weight_decay 0.05 \ + --seed 42 \ + --save_ckpt_freq 5 \ + --k_fold ${K_FOLD} \ + --num_workers 1 \ + --enable_deepspeed \ + --model_key teacher \ + --randaug +``` +- `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretraining). +- `--randaug`: perform image augmentation. + + +Parallelize the training of 32,768x32,768 images: + +```bash +# TASK - {"brca", "kidney", "lung"} +# K_FOLD - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +python -m torch.distributed.launch --nproc_per_node=8 run_longvit_finetuning.py \ + --input_size 32768 \ + --model longvit_small_patch32_32768 \ + --task tcga_${TASK}_survival \ + --batch_size 2 \ + --layer_decay 1.0 \ + --lr 5e-5 \ + --update_freq 4 \ + --epochs 10 \ + --warmup_epochs 1 \ + --drop_path 0.1 \ + --finetune /your_longvit_model_path/longvit_small_patch32_1024.pth + --data_path ./subtyping_split_index/tcga_${TASK} \ + --image_dir /path/to/your_splited_WSIs \ + --output_dir /path/to/save/your_model \ + --log_dir /path/to/save/your_model/log \ + --weight_decay 0.05 \ + --seed 42 \ + --save_ckpt_freq 5 \ + --k_fold ${K_FOLD} \ + --num_workers 1 \ + --enable_deepspeed \ + --model_key teacher \ + --seq_parallel +``` +- `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretraining). +- `--seq_parallel`: parallelize the training for very large images. 
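As a rough sense of scale: a 32,768x32,768 image with patch size 32 produces 1,024x1,024 = 1,048,576 patch tokens, so with 8 GPUs each rank processes 131,072 tokens of the split sequence, and the per-rank features only need to be recombined once before pooling for the prediction head. The sketch below illustrates that final gather step under stated assumptions; `gather_patch_features` is a hypothetical helper written for this guide, not the repository's own utility, and it ignores gradient flow through the remote slices for brevity.

```python
# Minimal sketch of recombining per-rank patch features in a sequence-parallel
# run. Assumes torch.distributed is already initialized (e.g. by the launcher)
# and that every rank holds a contiguous slice of the token sequence.
# NOTE: dist.all_gather does not backpropagate into the other ranks' slices;
# a training implementation needs an autograd-aware gather.
import torch
import torch.distributed as dist

def gather_patch_features(local_tokens: torch.Tensor) -> torch.Tensor:
    """local_tokens: (batch, local_seq_len, dim) slice owned by this rank."""
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(local_tokens) for _ in range(world_size)]
    dist.all_gather(gathered, local_tokens)
    # Concatenate along the sequence dimension to restore the full sequence.
    return torch.cat(gathered, dim=1)            # (batch, world_size * local_seq_len, dim)

# Token bookkeeping for the 32,768x32,768 setting with patch size 32 on 8 GPUs:
num_patches = (32768 // 32) ** 2                 # 1,048,576 tokens in total
tokens_per_rank = num_patches // 8               # 131,072 tokens per GPU
```

A single collective right before mean pooling keeps communication to one gather per forward pass, which matches the pattern used by the sequence-parallel classifier in this example (features are gathered just before pooling and the classification head).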
+ + +## Example: Evaluate LongViT on TCGA Subtyping + +```bash +# IMAGE_SIZE - {1024, 4096, 8192, 16384, 32768} +# TASK - {"brca", "kidney", "lung"} +# K_FOLD - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +python -m torch.distributed.launch --nproc_per_node=1 run_longvit_finetuning.py \ + --input_size ${IMAGE_SIZE} \ + --model longvit_small_patch32_${IMAGE_SIZE} \ + --task tcga_${TASK}_survival \ + --batch_size 1 \ + --layer_decay 1.0 \ + --lr 5e-5 \ + --update_freq 1 \ + --epochs 10 \ + --warmup_epochs 1 \ + --drop_path 0.1 \ + --finetune /path/to/save/your_model/checkpoint-best/mp_rank_00_model_states.pt \ + --data_path ./survival_split_index/tcga_${TASK} \ + --image_dir /path/to/your_resized_WSIs \ + --output_dir /path/to/save/your_model \ + --log_dir /path/to/save/your_model/log \ + --weight_decay 0.05 \ + --seed 42 \ + --save_ckpt_freq 5 \ + --k_fold ${K_FOLD} \ + --num_workers 1 \ + --enable_deepspeed \ + --model_key module \ + --eval \ + --no_auto_resume +``` +- `--eval`: performing evaluation. +- `--finetune`: best val model. + +For the model trained with sequence parallel, add `--seq_parallel` and use the same number of GPUs as training to perform evaluation. \ No newline at end of file diff --git a/examples/longvit/longvit.py b/examples/longvit/longvit.py new file mode 100644 index 00000000..58db485b --- /dev/null +++ b/examples/longvit/longvit.py @@ -0,0 +1,240 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Mostly copy-paste from timm library. +https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +""" +import math +from functools import partial + +import utils +import torch +import torch.nn as nn + +from torchscale.architecture.encoder import Encoder +from torchscale.model.LongNet import LongNetEncoder +from torchscale.architecture.config import EncoderConfig +from timm.models.layers import trunc_normal_ as __call_trunc_normal_ + + +def trunc_normal_(tensor, mean=0., std=1.): + __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, return_attention=False): + y, attn = self.attn(self.norm1(x)) + if return_attention: + return attn + x = x + self.drop_path(y) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + num_patches = (img_size // patch_size) * (img_size // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class LongViT(nn.Module): + """ Vision Transformer """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, flash_attention=True, dilated_ratio="[1,2,4,8,16]", segment_length="[64,128,256,512,1024]", checkpoint_activations=False, seq_parallel=False, **kwargs): + super().__init__() + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + if img_size == 4096: + segment_length = "[1024,2048,4096,8192,16384]" + elif img_size == 8192: + segment_length = "[1024,4096,8192,16384,65536]" + elif img_size == 16384: + segment_length = "[1024,4096,16384,65536,262144]" + elif img_size == 32768: + segment_length = "[1024,4096,32768,262144,1048576]" + + self.seq_parallel = seq_parallel + encoder_config = EncoderConfig( + img_size=img_size, patch_size=patch_size, vocab_size=64010, multiway=False, + layernorm_embedding=False, normalize_output=False, no_output_layer=True, + drop_path_rate=drop_path_rate, encoder_embed_dim=embed_dim, encoder_attention_heads=num_heads, + encoder_ffn_embed_dim=int(embed_dim * mlp_ratio), encoder_layers=depth, + checkpoint_activations=checkpoint_activations, flash_attention=flash_attention, + dilated_ratio=dilated_ratio, segment_length=segment_length, seq_parallel=seq_parallel, + ) + if flash_attention: + print("Using Torchscale LoneNetEncoder") + print("segment_length: {}".format(encoder_config.segment_length)) + print("dilated_ratio: {}".format(encoder_config.dilated_ratio)) + print("checkpoint_activations: {}".format(encoder_config.checkpoint_activations)) + print("drop_path_rate: {}".format(encoder_config.drop_path_rate)) + self.encoder = LongNetEncoder(encoder_config, embed_tokens=None, embed_positions=None, + output_projection=None, is_encoder_decoder=False,) + else: + print("Using Torchscale Encoder") + self.encoder = Encoder(encoder_config, embed_tokens=None, embed_positions=None, + output_projection=None, is_encoder_decoder=False,) + + trunc_normal_(self.pos_embed, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif 
isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def interpolate_pos_encoding(self, x, w, h): + npatch = x.shape[1] + N = self.pos_embed.shape[1] + if npatch == N and w == h: + return self.pos_embed + patch_pos_embed = self.pos_embed + dim = x.shape[-1] + w0 = w // self.patch_embed.patch_size + h0 = h // self.patch_embed.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode='bicubic', + ) + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def prepare_tokens(self, x): + B, nc, w, h = x.shape + x = self.patch_embed(x) # patch linear embedding + + # add positional encoding to each token + if self.seq_parallel: + rank_seq_len = x.size(1) + cur_rank = utils.get_rank() + start_idx = cur_rank * rank_seq_len + end_idx = (cur_rank + 1) * rank_seq_len + x = x + self.pos_embed[:, start_idx:end_idx, :] + else: + x = x + self.interpolate_pos_encoding(x, w, h) + + return self.pos_drop(x) + + def forward(self, x): + x = self.prepare_tokens(x) + x = self.encoder(src_tokens=None, token_embeddings=x)["encoder_out"] + return x diff --git a/examples/longvit/modeling_finetune.py b/examples/longvit/modeling_finetune.py new file mode 100644 index 00000000..c2994877 --- /dev/null +++ b/examples/longvit/modeling_finetune.py @@ -0,0 +1,223 @@ +# -------------------------------------------------------- +# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442) +# Github source: https://github.com/microsoft/unilm/tree/master/beit3 +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# --------------------------------------------------------' + +import utils +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from timm.models.registry import register_model +from functools import partial +from longvit import LongViT +from torchscale.architecture.config import EncoderConfig +from timm.models.layers import trunc_normal_ as __call_trunc_normal_ + + +def _get_small_config( + img_size=1024, patch_size=32, drop_path_rate=0, + checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs +): + return EncoderConfig( + img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=False, + layernorm_embedding=False, normalize_output=False, no_output_layer=True, + drop_path_rate=drop_path_rate, encoder_embed_dim=384, encoder_attention_heads=16, + encoder_ffn_embed_dim=int(384 * mlp_ratio), encoder_layers=12, + checkpoint_activations=checkpoint_activations, + ) + + +def trunc_normal_(tensor, mean=0., std=1.): + __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) + + +class LongViTForTCGAClassification(nn.Module): + def __init__( + self, + args, + num_classes, + norm_layer=nn.LayerNorm, + seq_parallel=False, + **kwargs + ): + super().__init__() + self.model = LongViT( + img_size=args.img_size, patch_size=args.patch_size, embed_dim=args.encoder_embed_dim, + depth=args.encoder_layers, num_heads=args.encoder_attention_heads, + mlp_ratio=4, 
drop_path_rate=args.drop_path_rate, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), + checkpoint_activations=args.checkpoint_activations, seq_parallel=seq_parallel + ) + embed_dim = args.encoder_embed_dim + self.depth = args.encoder_layers + self.fc_norm = norm_layer(embed_dim) + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + self.fc_norm.apply(self._init_weights) + self.head.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return self.depth + + @torch.jit.ignore + def no_weight_decay(self): + return {'model.pos_embed'} + + def forward(self, image, **kwargs): + x = self.model(image) + t = x[:, :, :] + cls_x = self.fc_norm(t.mean(1)) + return self.head(cls_x) + + +class LongViTForTCGAClassificationSeqParallel(nn.Module): + def __init__( + self, + args, + num_classes, + norm_layer=nn.LayerNorm, + seq_parallel=False, + **kwargs + ): + super().__init__() + self.model = LongViT( + img_size=args.img_size, patch_size=args.patch_size, embed_dim=args.encoder_embed_dim, + depth=args.encoder_layers, num_heads=args.encoder_attention_heads, + mlp_ratio=4, drop_path_rate=args.drop_path_rate, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), + checkpoint_activations=args.checkpoint_activations, seq_parallel=seq_parallel, + ) + embed_dim = args.encoder_embed_dim + self.depth = args.encoder_layers + self.fc_norm = norm_layer(embed_dim) + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + self.fc_norm.apply(self._init_weights) + self.head.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return self.depth + + @torch.jit.ignore + def no_weight_decay(self): + return {'model.pos_embed'} + + def forward(self, image, **kwargs): + x = self.model(image) + t = x[:, :, :].contiguous() + gatherd_t = utils.gather_tcga_features(t) + cls_x = self.fc_norm(gatherd_t.mean(1)) + return self.head(cls_x) + + +@register_model +def longvit_small_patch32_1024_tcga_subtyping(pretrained=False, task=None, **kwargs): + args = _get_small_config(img_size=1024, patch_size=32, **kwargs) + if task == "tcga_kidney_subtyping": + model = LongViTForTCGAClassification(args, num_classes=3, **kwargs) + else: + model = LongViTForTCGAClassification(args, num_classes=2, **kwargs) + return model + + +@register_model +def longvit_small_patch32_4096_tcga_subtyping(pretrained=False, task=None, **kwargs): + args = _get_small_config(img_size=4096, patch_size=32, **kwargs) + if task == "tcga_kidney_subtyping": + model = LongViTForTCGAClassification(args, num_classes=3, **kwargs) + else: + model = LongViTForTCGAClassification(args, num_classes=2, **kwargs) + return model + + +@register_model +def longvit_small_patch32_8192_tcga_subtyping(pretrained=False, task=None, **kwargs): + args = _get_small_config(img_size=8192, patch_size=32, **kwargs) + args.checkpoint_activations = True + if task == "tcga_kidney_subtyping": + model = LongViTForTCGAClassification(args, num_classes=3, 
**kwargs) + else: + model = LongViTForTCGAClassification(args, num_classes=2, **kwargs) + return model + + +@register_model +def longvit_small_patch32_16384_tcga_subtyping(pretrained=False, task=None, **kwargs): + args = _get_small_config(img_size=16384, patch_size=32, **kwargs) + args.checkpoint_activations = True + if task == "tcga_kidney_subtyping": + model = LongViTForTCGAClassification(args, num_classes=3, **kwargs) + else: + model = LongViTForTCGAClassification(args, num_classes=2, **kwargs) + return model + + +@register_model +def longvit_small_patch32_32768_tcga_subtyping(pretrained=False, task=None, seq_parallel=False, **kwargs): + args = _get_small_config(img_size=32768, patch_size=32, **kwargs) + args.checkpoint_activations = True + if task == "tcga_kidney_subtyping": + model = LongViTForTCGAClassificationSeqParallel(args, num_classes=3, seq_parallel=seq_parallel, **kwargs) + else: + model = LongViTForTCGAClassificationSeqParallel(args, num_classes=2, seq_parallel=seq_parallel, **kwargs) + return model + + +@register_model +def longvit_small_patch32_1024_tcga_survival(pretrained=False, task=None, **kwargs): + args = _get_small_config(img_size=1024, patch_size=32, **kwargs) + model = LongViTForTCGAClassification(args, num_classes=4, **kwargs) + return model + + +@register_model +def longvit_small_patch32_4096_tcga_survival(pretrained=False, task=None, **kwargs): + args = _get_small_config(img_size=4096, patch_size=32, **kwargs) + model = LongViTForTCGAClassification(args, num_classes=4, **kwargs) + return model + + +@register_model +def longvit_small_patch32_8192_tcga_survival(pretrained=False, task=None, **kwargs): + args = _get_small_config(img_size=8192, patch_size=32, **kwargs) + args.checkpoint_activations = True + model = LongViTForTCGAClassification(args, num_classes=4, **kwargs) + return model + + +@register_model +def longvit_small_patch32_16384_tcga_survival(pretrained=False, task=None, **kwargs): + args = _get_small_config(img_size=16384, patch_size=32, **kwargs) + args.checkpoint_activations = True + model = LongViTForTCGAClassification(args, num_classes=4, **kwargs) + return model + + +@register_model +def longvit_small_patch32_32768_tcga_survival(pretrained=False, task=None, seq_parallel=False, **kwargs): + args = _get_small_config(img_size=32768, patch_size=32, **kwargs) + args.checkpoint_activations = True + model = LongViTForTCGAClassificationSeqParallel(args, num_classes=4, seq_parallel=seq_parallel, **kwargs) + return model diff --git a/examples/longvit/optim_factory.py b/examples/longvit/optim_factory.py new file mode 100644 index 00000000..744c8c2a --- /dev/null +++ b/examples/longvit/optim_factory.py @@ -0,0 +1,128 @@ +# -------------------------------------------------------- +# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442) +# Github source: https://github.com/microsoft/unilm/tree/master/beit3 +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# --------------------------------------------------------' + +from torch import optim as optim +from timm.optim.lookahead import Lookahead + +import json + + +def get_num_layer_for_vit(var_name, num_max_layer): + if "embed" in var_name: + return 0 + elif var_name in ( + "cls_token", "mask_token", "pos_embed", "model.pos_embed", "language_pos_embed", + "word_embeddings.weight", "vision_cls_token", "vision_pos_embed" + ): + return 0 + elif var_name.startswith("patch_embed"): + return 0 + elif 
var_name.startswith("rel_pos_bias"): + return num_max_layer - 1 + elif "layers." in var_name: + layer_id = int(var_name.split('layers.')[1].split('.')[0]) + return layer_id + 1 + else: + return num_max_layer - 1 + + +def get_is_head_flag_for_vit(var_name, num_max_layer): + if var_name.startswith("head"): + return 1 + # elif var_name.startswith("pooler"): + # return 1 + else: + return 0 + + +class LayerDecayValueAssigner(object): + def __init__(self, values, scale_handler=None): + self.scale_handler = scale_handler or get_num_layer_for_vit + self.values = values + + def get_scale(self, layer_id): + return self.values[layer_id] + + def get_layer_id(self, var_name): + return self.scale_handler(var_name, len(self.values)) + + +# The implementation code is modified from Timm (https://github.com/huggingface/pytorch-image-models/tree/main/timm +def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): + parameter_group_names = {} + parameter_group_vars = {} + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: + group_name = "no_decay" + this_weight_decay = 0. + else: + group_name = "decay" + this_weight_decay = weight_decay + if get_num_layer is not None: + layer_id = get_num_layer(name) + group_name = "layer_%d_%s" % (layer_id, group_name) + else: + layer_id = None + + if group_name not in parameter_group_names: + if get_layer_scale is not None: + scale = get_layer_scale(layer_id) + else: + scale = 1. + + parameter_group_names[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale + } + parameter_group_vars[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale + } + + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) + return list(parameter_group_vars.values()) + + +def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): + opt_lower = args.opt.lower() + weight_decay = args.weight_decay + if weight_decay and filter_bias_and_bn: + skip = {} + if skip_list is not None: + skip = skip_list + elif hasattr(model, 'no_weight_decay'): + skip = model.no_weight_decay() + parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) + weight_decay = 0. + else: + parameters = model.parameters() + + opt_args = dict(lr=args.lr, weight_decay=weight_decay) + if hasattr(args, 'opt_eps') and args.opt_eps is not None: + opt_args['eps'] = args.opt_eps + if hasattr(args, 'opt_betas') and args.opt_betas is not None: + opt_args['betas'] = args.opt_betas + + opt_split = opt_lower.split('_') + opt_lower = opt_split[-1] + if opt_lower == 'adamw': + optimizer = optim.AdamW(parameters, **opt_args) + else: + raise ValueError("Invalid optimizer") + + if len(opt_split) > 1: + if opt_split[0] == 'lookahead': + optimizer = Lookahead(optimizer) + + return optimizer diff --git a/examples/longvit/pretraining/vision_transformer.py b/examples/longvit/pretraining/vision_transformer.py new file mode 100644 index 00000000..675ee8a9 --- /dev/null +++ b/examples/longvit/pretraining/vision_transformer.py @@ -0,0 +1,266 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Mostly copy-paste from timm library. +https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +""" +import math +from functools import partial + +import torch +import torch.nn as nn + +from utils import trunc_normal_ +from torchscale.architecture.encoder import Encoder +from torchscale.model.LongNet import LongNetEncoder +from torchscale.architecture.config import EncoderConfig + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, return_attention=False): + y, attn = self.attn(self.norm1(x)) + if return_attention: + return attn + x = x + self.drop_path(y) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + num_patches = (img_size // patch_size) * (img_size // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + def __init__(self, img_size=1024, patch_size=32, in_chans=3, num_classes=0, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, flash_attention=True, dilated_ratio="[1,2,4,8,16]", segment_length="[64,128,256,512,1024]", checkpoint_activations=False, **kwargs): + super().__init__() + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + encoder_config = EncoderConfig( + img_size=img_size, patch_size=patch_size, vocab_size=64010, multiway=False, + layernorm_embedding=False, normalize_output=False, no_output_layer=True, + drop_path_rate=drop_path_rate, encoder_embed_dim=embed_dim, encoder_attention_heads=num_heads, + encoder_ffn_embed_dim=int(embed_dim * mlp_ratio), encoder_layers=depth, + checkpoint_activations=checkpoint_activations, flash_attention=flash_attention, + dilated_ratio=dilated_ratio, segment_length=segment_length, seq_parallel=False, + ) + if flash_attention: + print("Using Torchscale LoneNetEncoder") + self.encoder = LongNetEncoder(encoder_config, embed_tokens=None, embed_positions=None, + output_projection=None, is_encoder_decoder=False,) + else: + print("Using Torchscale Encoder") + self.encoder = Encoder(encoder_config, embed_tokens=None, embed_positions=None, + output_projection=None, is_encoder_decoder=False,) + + self.norm = norm_layer(embed_dim) + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def interpolate_pos_encoding(self, x, w, h): + npatch = x.shape[1] + N = self.pos_embed.shape[1] + if npatch == N and w == h: + return self.pos_embed + patch_pos_embed = self.pos_embed + dim = x.shape[-1] + w0 = w // self.patch_embed.patch_size + h0 = h // self.patch_embed.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at 
https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode='bicubic', + ) + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def prepare_tokens(self, x): + B, nc, w, h = x.shape + x = self.patch_embed(x) # patch linear embedding + + # add positional encoding to each token + x = x + self.interpolate_pos_encoding(x, w, h) + + return self.pos_drop(x) + + def forward(self, x): + x = self.prepare_tokens(x) + x = self.encoder(src_tokens=None, token_embeddings=x)["encoder_out"] + x = self.norm(x) + t = x[:, :, :] + cls_x = t.mean(1) + return cls_x + + +def vit_small(patch_size=32, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=384, depth=12, num_heads=16, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +class DINOHead(nn.Module): + def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): + super().__init__() + nlayers = max(nlayers, 1) + if nlayers == 1: + self.mlp = nn.Linear(in_dim, bottleneck_dim) + else: + layers = [nn.Linear(in_dim, hidden_dim)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim)) + self.mlp = nn.Sequential(*layers) + self.apply(self._init_weights) + self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + if norm_last_layer: + self.last_layer.weight_g.requires_grad = False + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + x = nn.functional.normalize(x, dim=-1, p=2) + x = self.last_layer(x) + return x diff --git a/examples/longvit/requirements.txt b/examples/longvit/requirements.txt new file mode 100644 index 00000000..46079564 --- /dev/null +++ b/examples/longvit/requirements.txt @@ -0,0 +1,16 @@ +torch==2.0.0 +timm==0.6.13 +Pillow==10.0.0 +blobfile==2.0.2 +mypy==1.4.1 +numpy==1.22.4 +pytest==7.2.2 +requests==2.31.0 +einops==0.6.1 +tensorboardX==1.8 +scipy==1.6.3 +ftfy==6.1.1 +opencv-python==4.8.0.74 +pyarrow==9.0.0 +transformers==4.8.1 +deepspeed==0.4.0 \ No newline at end of file diff --git a/examples/longvit/run_longvit_finetuning.py b/examples/longvit/run_longvit_finetuning.py new file mode 100644 index 00000000..d2c8b6e3 --- /dev/null +++ b/examples/longvit/run_longvit_finetuning.py @@ -0,0 +1,365 @@ +# -------------------------------------------------------- +# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442) +# Github source: https://github.com/microsoft/unilm/tree/master/beit3 +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# --------------------------------------------------------' + +import argparse +import datetime +import numpy as np +import time +import 
torch +import torch.backends.cudnn as cudnn +import json +import os + +from pathlib import Path + +from timm.models import create_model +from timm.utils import ModelEma +from optim_factory import create_optimizer, get_parameter_groups, \ + LayerDecayValueAssigner + +from engine_for_finetuning import train_one_epoch, get_handler, evaluate +from datasets import create_downstream_dataset +from utils import NativeScalerWithGradNormCount as NativeScaler +import utils +import modeling_finetune + + +def get_args(): + parser = argparse.ArgumentParser('LongViT fine-tuning and evaluation script', add_help=False) + + # Model parameters + parser.add_argument('--model', default='longvit_small_patch32_1024', type=str, metavar='MODEL', + help='Name of model to train') + parser.add_argument('--task', type=str, required=True, + choices=['tcga_brca_subtyping', 'tcga_lung_subtyping', 'tcga_kidney_subtyping', + 'tcga_ucec_survival', 'tcga_luad_survival', 'tcga_brca_survival'], + help='Name of task to fine-tuning') + parser.add_argument('--k_fold', type=int, default=0) + + parser.add_argument('--input_size', default=1024, type=int, + help='images input size') + parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', + help='Drop path rate (default: 0.1)') + + parser.add_argument('--checkpoint_activations', action='store_true', default=False, + help='Enable checkpointing to save your memory.') + + parser.add_argument('--seq_parallel', action='store_true', default=False, + help='Enable sequence parallel.') + + parser.add_argument('--model_ema', action='store_true', default=False) + parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='') + parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='') + + # Optimizer parameters + parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', + help='Optimizer (default: "adamw"') + parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', + help='Optimizer Epsilon (default: 1e-8)') + parser.add_argument('--opt_betas', default=[0.9, 0.999], type=float, nargs='+', metavar='BETA', + help='Optimizer Betas (default: 0.9, 0.999, use opt default)') + parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') + parser.add_argument('--weight_decay', type=float, default=0.05, + help='weight decay (default: 0.05)') + + parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', + help='learning rate (default: 5e-4)') + parser.add_argument('--layer_decay', type=float, default=0.9) + + parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR', + help='warmup learning rate (default: 1e-6)') + parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0 (1e-6)') + parser.add_argument('--warmup_epochs', type=int, default=1, metavar='N', + help='epochs to warmup LR, if scheduler supports') + parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', + help='num of steps to warmup LR, will overload warmup_epochs if set > 0') + + parser.add_argument('--batch_size', default=1, type=int) + parser.add_argument('--eval_batch_size', default=None, type=int) + parser.add_argument('--epochs', default=10, type=int) + parser.add_argument('--update_freq', default=1, type=int) + parser.add_argument('--save_ckpt_freq', default=5, type=int) + + # Augmentation parameters + parser.add_argument('--randaug', 
action='store_true', default=False) + parser.add_argument('--cached_randaug', action='store_true', default=False, help='Using cached augmented images for sequence parallel') + parser.add_argument('--train_interpolation', type=str, default='bicubic', + help='Training interpolation (random, bilinear, bicubic default: "bicubic")') + + # Finetuning params + parser.add_argument('--finetune', default='', + help='finetune from checkpoint') + parser.add_argument('--model_key', default='model|module', type=str) + parser.add_argument('--model_prefix', default='', type=str) + + # Dataset parameters + parser.add_argument('--data_path', default='', type=str, + help='index path') + parser.add_argument('--image_dir', default='', type=str, + help='slide images path') + + parser.add_argument('--output_dir', default='', + help='path where to save, empty for no saving') + parser.add_argument('--log_dir', default=None, + help='path where to tensorboard log') + parser.add_argument('--device', default='cuda', + help='device to use for training / testing') + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--resume', default='', + help='resume from checkpoint') + parser.add_argument('--auto_resume', action='store_true') + parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume') + parser.set_defaults(auto_resume=True) + + parser.add_argument('--save_ckpt', action='store_true') + parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt') + parser.set_defaults(save_ckpt=True) + + parser.add_argument('--start_epoch', default=0, type=int, metavar='N', + help='start epoch') + parser.add_argument('--eval', action='store_true', + help='Perform evaluation only') + parser.add_argument('--num_workers', default=10, type=int) + parser.add_argument('--pin_mem', action='store_true', + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') + parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') + parser.set_defaults(pin_mem=True) + + # distributed training parameters + parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_on_itp', action='store_true') + parser.add_argument('--dist_url', default='env://', + help='url used to set up distributed training') + + # label smoothing + parser.add_argument('--label_smoothing', type=float, default=0.1) + + # deepspeed parameters + parser.add_argument('--enable_deepspeed', action='store_true', default=False) + parser.add_argument('--initial_scale_power', type=int, default=16) + parser.add_argument('--zero_stage', default=0, type=int, + help='ZeRO optimizer stage (default: 0)') + + known_args, _ = parser.parse_known_args() + + if known_args.enable_deepspeed: + try: + import deepspeed + from deepspeed import DeepSpeedConfig + parser = deepspeed.add_config_arguments(parser) + ds_init = deepspeed.initialize + except: + print("Please 'pip install deepspeed==0.4.0'") + exit(0) + else: + ds_init = None + + return parser.parse_args(), ds_init + + +def main(args, ds_init): + utils.init_distributed_mode(args) + + if ds_init is not None: + utils.create_ds_config(args) + + print(args) + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + # random.seed(seed) + + cudnn.benchmark = True + + if utils.get_rank() == 0 and args.log_dir is not None: + 
os.makedirs(args.log_dir, exist_ok=True) + log_writer = utils.TensorboardLogger(log_dir=args.log_dir) + else: + log_writer = None + + data_loader_train, data_loader_val = create_downstream_dataset(args) + + if not args.model.endswith(args.task): + if args.task.endswith("subtyping"): + model_config = "%s_tcga_subtyping" % args.model + elif args.task.endswith("survival"): + model_config = "%s_tcga_survival" % args.model + else: + model_config = "%s_%s" % (args.model, args.task) + else: + model_config = args.model + print("model_config = %s" % model_config) + model = create_model( + model_config, + pretrained=False, + task=args.task, + drop_path_rate=args.drop_path, + checkpoint_activations=args.checkpoint_activations, + seq_parallel=args.seq_parallel, + ) + + if args.finetune: + utils.load_model_and_may_interpolate(args.finetune, model, args.model_key, args.model_prefix, args.eval) + + model.to(device) + + model_ema = None + if args.model_ema: + # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper + model_ema = ModelEma( + model, + decay=args.model_ema_decay, + device='cpu' if args.model_ema_force_cpu else '', + resume='') + print("Using EMA with decay = %.8f" % args.model_ema_decay) + + model_without_ddp = model + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + + print("Model = %s" % str(model_without_ddp)) + print('number of params:', n_parameters) + + if args.seq_parallel: + total_batch_size = args.batch_size * args.update_freq + num_training_steps_per_epoch = len(data_loader_train) // args.update_freq + if args.update_freq > 1 and len(data_loader_train) % args.update_freq != 0: + num_training_steps_per_epoch += 1 + else: + total_batch_size = args.batch_size * args.update_freq * utils.get_world_size() + num_training_steps_per_epoch = len(data_loader_train) + print("LR = %.8f" % args.lr) + print("Batch size = %d" % total_batch_size) + print("Update frequent = %d" % args.update_freq) + print("Number of training examples = %d" % len(data_loader_train.dataset)) + print("Number of training training per epoch = %d" % num_training_steps_per_epoch) + + num_layers = model_without_ddp.get_num_layers() + if args.layer_decay < 1.0: + lrs = list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)) + assigner = LayerDecayValueAssigner(lrs) + else: + assigner = None + + if assigner is not None: + print("Assigned values = %s" % str(assigner.values)) + + skip_weight_decay_list = model.no_weight_decay() + + if args.distributed: + torch.distributed.barrier() + if args.enable_deepspeed: + loss_scaler = None + optimizer_params = get_parameter_groups( + model, args.weight_decay, skip_weight_decay_list, + assigner.get_layer_id if assigner is not None else None, + assigner.get_scale if assigner is not None else None) + model, optimizer, _, _ = ds_init( + args=args, model=model, model_parameters=optimizer_params, + dist_init_required=not args.distributed, + ) + + print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps()) + assert model.gradient_accumulation_steps() == args.update_freq + else: + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) + model_without_ddp = model.module + + optimizer = create_optimizer( + args, model_without_ddp, skip_list=skip_weight_decay_list, + get_num_layer=assigner.get_layer_id if assigner is not None else None, + get_layer_scale=assigner.get_scale if assigner is not None else 
None) + loss_scaler = NativeScaler() + + lr_schedule_values = utils.cosine_scheduler( + args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, + warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, + ) + + utils.auto_load_model( + args=args, model=model, model_without_ddp=model_without_ddp, + optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema) + + task_handler = get_handler(args) + + if args.eval: + data_loader_test = create_downstream_dataset(args, is_eval=True) + if args.task in ["tcga_brca_subtyping", "tcga_lung_subtyping", "tcga_kidney_subtyping", + 'tcga_ucec_survival', 'tcga_luad_survival', 'tcga_brca_survival']: + ext_test_stats, task_key = evaluate(data_loader_test, model, device, task_handler) + print(f"Accuracy of the network on the {len(data_loader_test.dataset)} test images: {ext_test_stats[task_key]:.3f}%") + exit(0) + else: + raise NotImplementedError() + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + + max_accuracy = 0.0 + for epoch in range(args.start_epoch, args.epochs): + os.environ['cur_epoch'] = str(epoch) + if args.distributed and not args.seq_parallel: + data_loader_train.sampler.set_epoch(epoch) + if log_writer is not None: + log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq) + train_stats = train_one_epoch( + model, data_loader_train, optimizer, device, task_handler, epoch, + epoch * num_training_steps_per_epoch, lr_schedule_values, loss_scaler, + args.clip_grad, args.update_freq, model_ema, log_writer, args.task, args.seq_parallel, + ) + if args.output_dir and args.save_ckpt: + if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: + utils.save_model( + args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema) + if data_loader_val is not None: + val_stats, task_key = evaluate(data_loader_val, model, device, task_handler) + + print(f"Performance of the network on the {len(data_loader_val.dataset)} val images: {val_stats[task_key]:.1f}%") + if max_accuracy < val_stats[task_key]: + max_accuracy = val_stats[task_key] + if args.output_dir and args.save_ckpt: + utils.save_model( + args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch="best", model_ema=model_ema) + + print(f'Max performance: {max_accuracy:.5f}%') + if log_writer is not None: + log_writer.update(acc=val_stats[task_key], head="perf", step=epoch) + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + **{f'val_{k}': v for k, v in val_stats.items()}, + 'epoch': epoch, + 'n_parameters': n_parameters} + else: + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch, + 'n_parameters': n_parameters} + + if args.output_dir and utils.is_main_process(): + if log_writer is not None: + log_writer.flush() + with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +if __name__ == '__main__': + opts, ds_init = get_args() + if opts.output_dir: + Path(opts.output_dir).mkdir(parents=True, exist_ok=True) + main(opts, ds_init) diff --git a/examples/longvit/utils.py b/examples/longvit/utils.py new file mode 100644 index 00000000..aa529ecf --- /dev/null +++ b/examples/longvit/utils.py @@ -0,0 +1,717 
@@ +# -------------------------------------------------------- +# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442) +# Github source: https://github.com/microsoft/unilm/tree/master/beit3 +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# --------------------------------------------------------' + +import datetime +import io +import os +import math +import time +import json +import argparse +import numpy as np +from pathlib import Path +from collections import defaultdict, deque +from timm.utils import get_state_dict + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch._six import inf +# from torchmetrics import Metric +from tensorboardX import SummaryWriter + + +def bool_flag(s): + """ + Parse boolean arguments from the command line. + """ + FALSY_STRINGS = {"off", "false", "0"} + TRUTHY_STRINGS = {"on", "true", "1"} + if s.lower() in FALSY_STRINGS: + return False + elif s.lower() in TRUTHY_STRINGS: + return True + else: + raise argparse.ArgumentTypeError("invalid value for a boolean flag") + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if 
not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +class TensorboardLogger(object): + def __init__(self, log_dir): + self.writer = SummaryWriter(logdir=log_dir) + self.step = 0 + + def set_step(self, step=None): + if step is not None: + self.step = step + else: + self.step += 1 + + def update(self, head='scalar', step=None, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step) + + def flush(self): + self.writer.flush() + + +def _load_checkpoint_for_ema(model_ema, checkpoint): + """ + Workaround for ModelEma._load_checkpoint to accept an already-loaded object + """ + mem_file = io.BytesIO() + torch.save(checkpoint, mem_file) + mem_file.seek(0) + model_ema._load_checkpoint(mem_file) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def _get_rank_env(): + if "RANK" in os.environ: + return int(os.environ["RANK"]) + else: + return int(os.environ['OMPI_COMM_WORLD_RANK']) + + +def _get_local_rank_env(): + if "LOCAL_RANK" in os.environ: + return int(os.environ["LOCAL_RANK"]) + else: + return int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + + +def _get_world_size_env(): + if "WORLD_SIZE" in os.environ: + return int(os.environ["WORLD_SIZE"]) + else: + return int(os.environ['OMPI_COMM_WORLD_SIZE']) + + +# The implementation code is modified from DeiT 
(https://github.com/facebookresearch/deit.git) +def init_distributed_mode(args): + if args.dist_on_itp: + args.rank = _get_rank_env() + args.world_size = _get_world_size_env() # int(os.environ['OMPI_COMM_WORLD_SIZE']) + args.gpu = _get_local_rank_env() + args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_IP'], os.environ['MASTER_PORT']) + os.environ['LOCAL_RANK'] = str(args.gpu) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}, gpu {}'.format( + args.rank, args.dist_url, args.gpu), flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank, + timeout=datetime.timedelta(0, 7200) + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(model, prefix=prefix) + + warn_missing_keys = [] + ignore_missing_keys = [] + for key in missing_keys: + keep_flag = True + for ignore_key in ignore_missing.split('|'): + if ignore_key in key: + keep_flag = False + break + if keep_flag: + warn_missing_keys.append(key) + else: + ignore_missing_keys.append(key) + + missing_keys = warn_missing_keys + + if len(missing_keys) > 0: + print("Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, missing_keys)) + if len(unexpected_keys) > 0: + print("Weights from pretrained model not used in {}: {}".format( + model.__class__.__name__, unexpected_keys)) + if len(ignore_missing_keys) > 0: + print("Ignored weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, ignore_missing_keys)) + if len(error_msgs) > 0: + print('\n'.join(error_msgs)) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + 
self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + return total_norm + + +def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, + start_warmup_value=0, warmup_steps=-1, sched_type="cos"): + warmup_schedule = np.array([]) + warmup_iters = warmup_epochs * niter_per_ep + if warmup_steps > 0: + warmup_iters = warmup_steps + print("Set warmup steps = %d" % warmup_iters) + if warmup_epochs > 0: + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + if sched_type == "cos": + iters = np.arange(epochs * niter_per_ep - warmup_iters) + schedule = np.array([ + final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) + elif sched_type == "linear": + schedule = np.linspace(base_value, final_value, epochs * niter_per_ep - warmup_iters) + else: + raise NotImplementedError() + + schedule = np.concatenate((warmup_schedule, schedule)) + + assert len(schedule) == epochs * niter_per_ep + return schedule + + +def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): + output_dir = Path(args.output_dir) + if loss_scaler is not None: + checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch)] + for checkpoint_path in checkpoint_paths: + to_save = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch, + 'scaler': loss_scaler.state_dict(), + 'args': args, + } + + if model_ema is not None: + to_save['model_ema'] = get_state_dict(model_ema) + + save_on_master(to_save, checkpoint_path) + else: + client_state = {'epoch': epoch, "args": args} + if model_ema is not None: + client_state['model_ema'] = get_state_dict(model_ema) + model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch, client_state=client_state) + + +def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): + output_dir = Path(args.output_dir) + if loss_scaler is not None: + # torch.amp + if args.auto_resume and len(args.resume) == 0: + import glob + all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) + latest_ckpt = -1 + for ckpt in all_checkpoints: + t = ckpt.split('-')[-1].split('.')[0] + if t.isdigit(): + latest_ckpt = max(int(t), latest_ckpt) + if latest_ckpt >= 0: + args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) + print("Auto resume checkpoint: %s" % args.resume) + + if args.resume: + if args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + model_without_ddp.load_state_dict(checkpoint['model']) + 
print("Resume checkpoint %s" % args.resume) + if 'optimizer' in checkpoint and 'epoch' in checkpoint: + optimizer.load_state_dict(checkpoint['optimizer']) + args.start_epoch = checkpoint['epoch'] + 1 + if hasattr(args, 'model_ema') and args.model_ema: + _load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) + if 'scaler' in checkpoint: + loss_scaler.load_state_dict(checkpoint['scaler']) + print("With optim & sched!") + else: + # deepspeed, only support '--auto_resume'. + if args.auto_resume: + import glob + all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*')) + latest_ckpt = -1 + for ckpt in all_checkpoints: + t = ckpt.split('-')[-1].split('.')[0] + if t.isdigit(): + latest_ckpt = max(int(t), latest_ckpt) + if latest_ckpt >= 0: + args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt) + print("Auto resume checkpoint: %d" % latest_ckpt) + _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt) + args.start_epoch = client_states['epoch'] + 1 + if model_ema is not None: + if args.model_ema: + _load_checkpoint_for_ema(model_ema, client_states['model_ema']) + + +# The implementation code is modified from DeiT (https://github.com/facebookresearch/deit.git) +def load_model_and_may_interpolate(ckpt_path, model, model_key, model_prefix, is_eval=False): + if ckpt_path.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + ckpt_path, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(ckpt_path, map_location='cpu') + + print("Load ckpt from %s" % ckpt_path) + checkpoint_model = None + for model_key in model_key.split('|'): + if model_key in checkpoint: + checkpoint_model = checkpoint[model_key] + print("Load state_dict by model_key = %s" % model_key) + break + + if checkpoint_model is None: + checkpoint_model = checkpoint + + checkpoint_model = {k.replace("module.backbone.", "model.").replace("backbone.", "model."):v for k,v in checkpoint_model.items()} + + state_dict = model.state_dict() + for k in ['head.weight', 'head.bias']: + if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: + print(f"Removing key {k} from pretrained checkpoint") + del checkpoint_model[k] + + # interpolate position embedding + for pos_embed_key in ("model.pos_embed",): + if pos_embed_key in checkpoint_model: + pos_embed_checkpoint = checkpoint_model[pos_embed_key] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.model.patch_embed.num_patches + num_extra_tokens = getattr(model.model, "pos_embed").shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens.float(), size=(new_size, new_size), mode='bicubic', align_corners=False).half() + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model[pos_embed_key] = 
new_pos_embed + + load_state_dict(model, checkpoint_model, prefix=model_prefix) + + +def create_ds_config(args): + args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json") + with open(args.deepspeed_config, mode="w") as writer: + ds_config = { + "train_batch_size": args.batch_size * args.update_freq * get_world_size(), + "train_micro_batch_size_per_gpu": args.batch_size, + "steps_per_print": 1000, + "optimizer": { + "type": "Adam", + "adam_w_mode": True, + "params": { + "lr": args.lr, + "weight_decay": args.weight_decay, + "bias_correction": True, + "betas": [ + args.opt_betas[0], + args.opt_betas[1] + ], + "eps": args.opt_eps + } + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": getattr(args, "initial_scale_power", 12), + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "amp": { + "enabled": False, + "opt_level": "O2" + } + } + + if args.clip_grad is not None: + ds_config.update({'gradient_clipping': args.clip_grad}) + + if args.zero_stage == 1: + ds_config.update({"zero_optimization": {"stage": args.zero_stage, "reduce_bucket_size": 5e8}}) + elif args.zero_stage > 1: + raise NotImplementedError() + + writer.write(json.dumps(ds_config, indent=2)) + + +def merge_batch_tensors_by_dict_key(batch): + batch_tensors = {} + for tensor_key in batch[0]: + if isinstance(batch[0][tensor_key], torch.Tensor): + batch_tensors[tensor_key] = torch.stack([d[tensor_key] for d in batch]) + elif 'label' in tensor_key: + batch_tensors[tensor_key] = torch.tensor([d[tensor_key] for d in batch], dtype=torch.long) + else: + batch_tensors[tensor_key] = torch.tensor([d[tensor_key] for d in batch], dtype=torch.float) + return batch_tensors + + +def get_loss_scale_for_deepspeed(model): + optimizer = model.optimizer + loss_scale = None + if hasattr(optimizer, 'loss_scale'): + loss_scale = optimizer.loss_scale + elif hasattr(optimizer, 'cur_scale'): + loss_scale = optimizer.cur_scale + return loss_scale + + +class GatherLayer(torch.autograd.Function): + """ + Gather tensors from all workers with support for backward propagation: + This implementation does not cut the gradients as torch.distributed.all_gather does. 
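+    In the backward pass, the gradients received for every gathered copy are
+    summed across ranks (all-reduce) and each rank keeps the slice that
+    corresponds to its own input.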
+ """ + @staticmethod + def forward(ctx, x): + output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] + dist.all_gather(output, x) + return tuple(output) + @staticmethod + def backward(ctx, *grads): + all_gradients = torch.stack(grads) + dist.all_reduce(all_gradients) + return all_gradients[dist.get_rank()] + + +def gather_tcga_features( + tcga_features, +): + gathered_tcga_features = GatherLayer.apply(tcga_features) + all_tcga_features = torch.cat(gathered_tcga_features, dim=1) + + return all_tcga_features + + +def write_result_to_jsonl(test_stats, result_file): + with open(result_file, mode="w", encoding="utf-8") as writer: + writer.write(json.dumps(test_stats, indent=None)) + + +def read_result_from_jsonl(result_file): + with open(result_file, mode="r", encoding="utf-8") as reader: + return json.load(reader) + + +# divide continuous time scale into k discrete bins in total, T_cont \in {[0, a_1), [a_1, a_2), ...., [a_(k-1), inf)} +# Y = T_discrete is the discrete event time: +# Y = -1 if T_cont \in (-inf, 0), Y = 0 if T_cont \in [0, a_1), Y = 1 if T_cont in [a_1, a_2), ..., Y = k-1 if T_cont in [a_(k-1), inf) +# discrete hazards: discrete probability of h(t) = P(Y=t | Y>=t, X), t = -1,0,1,2,...,k +# S: survival function: P(Y > t | X) +# all patients are alive from (-inf, 0) by definition, so P(Y=-1) = 0 +# h(-1) = 0 ---> do not need to model +# S(-1) = P(Y > -1 | X) = 1 ----> do not need to model +''' +Summary: neural network is hazard probability function, h(t) for t = 0,1,2,...,k-1 +corresponding Y = 0,1, ..., k-1. h(t) represents the probability that patient dies in [0, a_1), [a_1, a_2), ..., [a_(k-1), inf] +''' +def nll_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7): + batch_size = len(Y) + Y = Y.view(batch_size, 1) # ground truth bin, 1,2,...,k + c = c.view(batch_size, 1).float() #censorship status, 0 or 1 + if S is None: + S = torch.cumprod(1 - hazards, dim=1) # surival is cumulative product of 1 - hazards + # without padding, S(0) = S[0], h(0) = h[0] + S_padded = torch.cat([torch.ones_like(c), S], 1) #S(-1) = 0, all patients are alive from (-inf, 0) by definition + # after padding, S(0) = S[1], S(1) = S[2], etc, h(0) = h[0] + #h[y] = h(1) + #S[1] = S(1) + uncensored_loss = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y).clamp(min=eps)) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps))) + censored_loss = - c * torch.log(torch.gather(S_padded, 1, Y+1).clamp(min=eps)) + neg_l = censored_loss + uncensored_loss + loss = (1-alpha) * neg_l + alpha * uncensored_loss + loss = loss.mean() + return loss + + +class NLLSurvLoss(object): + def __init__(self, alpha=0.15): + self.alpha = alpha + + def __call__(self, hazards, S, Y, c, alpha=None): + if alpha is None: + return nll_loss(hazards, S, Y, c, alpha=self.alpha) + else: + return nll_loss(hazards, S, Y, c, alpha=alpha) diff --git a/setup.py b/setup.py index 42e12e3c..117eb920 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ license="MIT", url="https://github.com/microsoft/torchscale", packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), - install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.6.13"], + install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.6.13", "einops"], python_requires=">=3.8.0", classifiers=[ "Programming Language :: Python :: 3", diff --git a/torchscale/architecture/config.py b/torchscale/architecture/config.py index 5898f394..36fe0223 100644 --- a/torchscale/architecture/config.py +++ b/torchscale/architecture/config.py @@ -53,6 +53,24 @@ def 
__init__(self, **kwargs): self.ddp_rank = kwargs.pop("ddp_rank", 0) self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512) + # Dilated Attention + self.flash_attention = kwargs.pop("flash_attention", False) + self.segment_length = kwargs.pop("segment_length", None) + self.dilated_ratio = kwargs.pop("dilated_ratio", None) + self.seq_parallel = kwargs.pop("seq_parallel", False) + self.postprocessing() + + def override(self, args): + for hp in self.__dict__.keys(): + if getattr(args, hp, None) is not None: + self.__dict__[hp] = getattr(args, hp, None) + self.postprocessing() + + def postprocessing(self): + if self.segment_length is not None and self.segment_length != '': + self.segment_length = eval(self.segment_length) + if self.dilated_ratio is not None and self.dilated_ratio != '': + self.dilated_ratio = eval(self.dilated_ratio) if self.deepnorm: self.encoder_normalize_before = False @@ -65,11 +83,6 @@ def __init__(self, **kwargs): self.moe_second_expert_policy = "random" assert self.moe_freq > 0 and self.moe_expert_count > 0 - def override(self, args): - for hp in self.__dict__.keys(): - if getattr(args, hp, None) is not None: - self.__dict__[hp] = getattr(args, hp, None) - class DecoderConfig(object): def __init__(self, **kwargs): @@ -117,23 +130,36 @@ def __init__(self, **kwargs): self.ddp_rank = kwargs.pop("ddp_rank", 0) self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512) + # Dilated Attention + self.flash_attention = kwargs.pop("flash_attention", False) + self.segment_length = kwargs.pop("segment_length", None) + self.dilated_ratio = kwargs.pop("dilated_ratio", None) + self.seq_parallel = kwargs.pop("seq_parallel", False) + self.postprocessing() + + def override(self, args): + for hp in self.__dict__.keys(): + if getattr(args, hp, None) is not None: + self.__dict__[hp] = getattr(args, hp, None) + self.postprocessing() + + def postprocessing(self): + if self.segment_length is not None and self.segment_length != '': + self.segment_length = eval(self.segment_length) + if self.dilated_ratio is not None and self.dilated_ratio != '': + self.dilated_ratio = eval(self.dilated_ratio) if self.deepnorm: - self.decoder_normalize_before = False + self.encoder_normalize_before = False self.subln = False if self.subln: - self.decoder_normalize_before = True + self.encoder_normalize_before = True self.deepnorm = False if self.use_xmoe: self.moe_normalize_gate_prob_before_dropping = True self.moe_second_expert_policy = "random" assert self.moe_freq > 0 and self.moe_expert_count > 0 - def override(self, args): - for hp in self.__dict__.keys(): - if getattr(args, hp, None) is not None: - self.__dict__[hp] = getattr(args, hp, None) - class EncoderDecoderConfig(object): def __init__(self, **kwargs): @@ -189,26 +215,37 @@ def __init__(self, **kwargs): self.ddp_rank = kwargs.pop("ddp_rank", 0) self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512) + # Dilated Attention + self.flash_attention = kwargs.pop("flash_attention", False) + self.segment_length = kwargs.pop("segment_length", None) + self.dilated_ratio = kwargs.pop("dilated_ratio", None) + self.seq_parallel = kwargs.pop("seq_parallel", False) + self.postprocessing() + + def override(self, args): + for hp in self.__dict__.keys(): + if getattr(args, hp, None) is not None: + self.__dict__[hp] = getattr(args, hp, None) + self.postprocessing() + + def postprocessing(self): + if 
self.segment_length is not None and self.segment_length != '': + self.segment_length = eval(self.segment_length) + if self.dilated_ratio is not None and self.dilated_ratio != '': + self.dilated_ratio = eval(self.dilated_ratio) if self.deepnorm: self.encoder_normalize_before = False - self.decoder_normalize_before = False self.subln = False if self.subln: self.encoder_normalize_before = True - self.decoder_normalize_before = True self.deepnorm = False if self.use_xmoe: self.moe_normalize_gate_prob_before_dropping = True self.moe_second_expert_policy = "random" assert self.moe_freq > 0 and self.moe_expert_count > 0 - def override(self, args): - for hp in self.__dict__.keys(): - if getattr(args, hp, None) is not None: - self.__dict__[hp] = getattr(args, hp, None) - - + class RetNetConfig(object): def __init__(self, **kwargs): self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768) @@ -257,19 +294,22 @@ def __init__(self, **kwargs): self.ddp_rank = kwargs.pop("ddp_rank", 0) self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512) + self.postprocessing() + def override(self, args): + for hp in self.__dict__.keys(): + if getattr(args, hp, None) is not None: + self.__dict__[hp] = getattr(args, hp, None) + self.postprocessing() + + def postprocessing(self): if self.deepnorm: - self.decoder_normalize_before = False + self.encoder_normalize_before = False self.subln = False if self.subln: - self.decoder_normalize_before = True + self.encoder_normalize_before = True self.deepnorm = False if self.use_xmoe: self.moe_normalize_gate_prob_before_dropping = True self.moe_second_expert_policy = "random" assert self.moe_freq > 0 and self.moe_expert_count > 0 - - def override(self, args): - for hp in self.__dict__.keys(): - if getattr(args, hp, None) is not None: - self.__dict__[hp] = getattr(args, hp, None) diff --git a/torchscale/architecture/decoder.py b/torchscale/architecture/decoder.py index ed407b06..4006b0cf 100644 --- a/torchscale/architecture/decoder.py +++ b/torchscale/architecture/decoder.py @@ -155,6 +155,7 @@ def forward( attn_mask=self_attn_mask, rel_pos=self_attn_rel_pos, is_first_step=is_first_step, + is_causal=True, ) x = self.dropout_module(x) @@ -430,13 +431,16 @@ def forward( for idx, layer in enumerate(self.layers): if incremental_state is None or is_first_step: - self_attn_mask = torch.triu( - torch.zeros([x.size(1), x.size(1)]) - .float() - .fill_(float("-inf")) - .type_as(x), - 1, - ) + if not self.args.flash_attention: + self_attn_mask = torch.triu( + torch.zeros([x.size(1), x.size(1)]) + .float() + .fill_(float("-inf")) + .type_as(x), + 1, + ) + else: + self_attn_mask = None if is_first_step and incremental_state is not None: if idx not in incremental_state: incremental_state[idx] = {} diff --git a/torchscale/component/dilated_attention.py b/torchscale/component/dilated_attention.py new file mode 100644 index 00000000..55f6d4ef --- /dev/null +++ b/torchscale/component/dilated_attention.py @@ -0,0 +1,217 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import math + +import torch +import torch.nn.functional as F +from einops import rearrange + +from .multihead_attention import MultiheadAttention +from .utils import padding_to_multiple_of, all_gather_func, get_data_parallel_rank, get_data_parallel_world_size + + +class DilatedAttention(MultiheadAttention): + + def dense_to_sparse(self, x, ratio): + length = x.size(1) + padding = padding_to_multiple_of(length, ratio) + 
head_padding = padding_to_multiple_of(self.num_heads, ratio) + + if padding > 0 or head_padding > 0: + x = F.pad(x, (0, 0, 0, head_padding, 0, padding), value = 0.) + + x = rearrange(x, 'b (l r1) (r2 h) d -> b l h d r1 r2', r1=ratio, r2=ratio) + x = torch.diagonal(x, offset=0, dim1=4, dim2=5) + x = rearrange(x, 'b l h d r -> b l (r h) d') + + if head_padding > 0: + x = x[:, :, :self.num_heads] + + return x + + def sparse_to_dense(self, out, lse, ratio): + head_padding = padding_to_multiple_of(self.num_heads, ratio) + + if head_padding > 0: + out = F.pad(out, (0, 0, 0, head_padding), value = 0.) + lse = F.pad(lse, (0, 0, 0, head_padding), value = -1e8) + + out = rearrange(out, 'b l (r h) d -> b l h d r', r=ratio) + out = torch.diag_embed(out, offset=0, dim1=4, dim2=5) + out = rearrange(out, 'b l h d r1 r2 -> b (r2 h) (l r1) d', r1=ratio, r2=ratio) + + lse = rearrange(lse, 'b (r h) l -> b l h r', r=ratio) + lse = torch.diag_embed(lse, offset=0, dim1=3, dim2=4) + lse = lse.masked_fill_(lse==0, -1e8) + lse = rearrange(lse, 'b l h r1 r2 -> b (r2 h) (l r1) 1', r1=ratio, r2=ratio) + + if head_padding > 0: + out = out[:, :self.num_heads] + lse = lse[:, :self.num_heads] + + return out, lse + + def gather_kv(self, x, sl, seq_len, is_causal=True): + bsz = x.size(0) + assert sl % seq_len == 0 + num_rank_per_segment = sl // seq_len + + x = all_gather_func(x) + current_rank = get_data_parallel_rank() + x = rearrange(x, '(w b) l h d -> w b l h d', b=bsz) + + if is_causal: + if current_rank > 0: + x = x[:current_rank] + else: + x = x[:1] * 0 + + current_segment = current_rank // num_rank_per_segment * num_rank_per_segment + x = x[current_segment:current_segment+num_rank_per_segment] + + x = rearrange(x, 'w b l h d -> b (w l) h d') + return x + + def gathering(self, x, dr, sl, is_causal=True, offset=0, is_kv=False, seq_parall=True): + + curr_x = x + if offset > 0: + curr_x = F.pad(curr_x, (0, 0, 0, 0, offset % sl, 0), value=0.) + seq_len = curr_x.size(1) + should_gather_kv = is_kv and (get_data_parallel_world_size() > 1) and (sl > seq_len) and seq_parall + _sl = sl + sl = min(sl, seq_len) + padding = padding_to_multiple_of(seq_len, sl) + + if padding > 0: + curr_x = F.pad(curr_x, (0, 0, 0, 0, 0, padding), value = 0.) 
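+        # Split the (padded) sequence into segments of length sl, then subsample
+        # positions/heads according to the dilation ratio dr via dense_to_sparse;
+        # if sequence parallelism is enabled and a segment spans several ranks,
+        # gather_kv collects the missing K/V shards from the other ranks.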
+ + curr_x = rearrange(curr_x, 'b (n g) h d -> (b n) g h d', g=sl) + curr_x = self.dense_to_sparse(curr_x, dr) + + if should_gather_kv: + curr_x = self.gather_kv(curr_x, _sl, seq_len, is_causal) + + curr_x = rearrange(curr_x, 'b l h d -> (b h) l d') + + return curr_x + + def scattering(self, outs, lses, seq_len, bsz, offset=0): + assert len(outs) == len(lses) + assert len(outs) % len(self.args.dilated_ratio) == 0 + all_outs, all_lses = [], [] + drs = self.args.dilated_ratio + if len(outs) > len(drs): + drs = drs * (len(outs) // len(drs)) + + for dr, o, lse in zip(drs, outs, lses): + o = rearrange(o, 'b l (h d) -> b l h d', h=self.num_heads) + o, lse = self.sparse_to_dense(o, lse, dr) + o = rearrange(o, '(b n) h g d -> (b h) (n g) d', b=bsz) + lse = rearrange(lse, '(b n) h g 1 -> (b h) (n g) 1', b=bsz) + o = o[:, offset:offset+seq_len] + lse = lse[:, offset:offset+seq_len] + + all_outs.append(o) + all_lses.append(lse) + + with torch.no_grad(): + max_lse = torch.stack(all_lses, dim=0) + max_lse = max_lse.max(0)[0] + all_lses = [torch.exp(lse-max_lse) for lse in all_lses] + lse_sum = torch.stack(all_lses, dim=0).sum(0) + all_lses = [lse / lse_sum for lse in all_lses] + + out = 0 + for o, lse in zip(all_outs, all_lses): + out += o * lse.type_as(o) + out = rearrange(out, '(b h) l d -> b l (h d)', h=self.num_heads) + + return out + + def forward( + self, + query, + key, + value, + incremental_state=None, + key_padding_mask=None, + attn_mask=None, + rel_pos=None, + is_first_step=False, + is_causal=False, + ): + assert self.args.flash_attention + assert rel_pos is None + bsz, tgt_len, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}" + + key_bsz, src_len, _ = key.size() + assert key_bsz == bsz, f"{query.size(), key.size()}" + assert value is not None + assert bsz, src_len == value.shape[:2] + + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + q = rearrange(q, 'b l (h d) -> (b h) l d', h=self.num_heads) + k = rearrange(k, 'b l (h d) -> (b h) l d', h=self.num_heads) + v = rearrange(v, 'b l (h d) -> (b h) l d', h=self.num_heads) + + if incremental_state is not None and not is_first_step: + offset = src_len - 1 + else: + offset = 0 + + if incremental_state is not None: + if "prev_key" in incremental_state: + prev_key = incremental_state["prev_key"].view( + bsz * self.num_heads, -1, self.head_dim + ) + prev_value = incremental_state["prev_value"].view( + bsz * self.num_heads, -1, self.head_dim + ) + k = torch.cat([prev_key, k], dim=1) + v = torch.cat([prev_value, v], dim=1) + incremental_state["prev_key"] = k.view( + bsz, self.num_heads, -1, self.head_dim + ) + incremental_state["prev_value"] = v.view( + bsz, self.num_heads, -1, self.head_dim + ) + src_len = k.size(1) + + if self.xpos is not None: + if incremental_state is not None and not is_first_step: + offset = src_len - 1 + else: + offset = 0 + k = self.xpos(k, offset=0, downscale=True) + q = self.xpos(q, offset=offset, downscale=False) + + q = rearrange(q, '(b h) l d -> b l h d', h=self.num_heads) + k = rearrange(k, '(b h) l d -> b l h d', h=self.num_heads) + v = rearrange(v, '(b h) l d -> b l h d', h=self.num_heads) + + outs, lses = [], [] + for sl, dr in zip(self.args.segment_length, self.args.dilated_ratio): + ki = self.gathering(k, dr, sl, is_causal=is_causal, offset=0, is_kv=True, seq_parall=self.args.seq_parallel) + vi = self.gathering(v, dr, sl, is_causal=is_causal, offset=0, is_kv=True, seq_parall=self.args.seq_parallel) + qi = 
self.gathering(q, dr, sl, is_causal=is_causal, offset=offset, is_kv=False, seq_parall=self.args.seq_parallel) + + out, lse = self.attention_ops(qi, ki, vi, key_padding_mask=key_padding_mask, attn_mask=attn_mask, rel_pos=rel_pos, is_causal=is_causal) + + outs.append(out) + lses.append(lse) + + attn = self.scattering(outs, lses, tgt_len, bsz, offset=offset) + + if self.inner_attn_ln is not None: + attn = self.inner_attn_ln(attn) + + attn = self.out_proj(attn) + + return attn, None diff --git a/torchscale/component/flash_attention.py b/torchscale/component/flash_attention.py new file mode 100644 index 00000000..ac549344 --- /dev/null +++ b/torchscale/component/flash_attention.py @@ -0,0 +1,123 @@ +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License [see LICENSE for details] + + +from typing import Any, Optional +import torch + +if torch.cuda.is_available(): + try: + if torch.cuda.get_device_capability()[0] > 7: + from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func + + def flash_attn_func(q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False): + assert bias is None + attn, lse, _ = _flash_attn_func(q, k, v, dropout_p=dropout, softmax_scale=softmax_scale, causal=is_causal, return_attn_probs=True) + return attn, lse + + else: + from xformers.ops.fmha import ( + cutlass, + Inputs, + Context, + _memory_efficient_attention_forward_requires_grad, + _memory_efficient_attention_backward, + LowerTriangularMask, + ) + + class FlashAttnFunc(torch.autograd.Function): + @staticmethod + # type: ignore + def forward(ctx, q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False): + if is_causal: + assert bias is None + attn_bias = LowerTriangularMask() + else: + attn_bias = bias + + inp = Inputs( + query=q, + key=k, + value=v, + attn_bias=attn_bias, + p=dropout, + scale=softmax_scale, + ) + op_fw = cutlass.FwOp + op_bw = cutlass.BwOp + + out, op_ctx = _memory_efficient_attention_forward_requires_grad( + inp=inp, op=op_fw + ) + + # Saving attn_bias is a bit complicated, as the + # torch part should go in `save_for_backward` + if isinstance(inp.attn_bias, torch.Tensor): + attn_bias_tensor = inp.attn_bias + attn_bias_ctx = None + else: + attn_bias_tensor = None + attn_bias_ctx = inp.attn_bias + + ctx.save_for_backward( + inp.query, + inp.key, + inp.value, + op_ctx.out, + op_ctx.lse, + ) + ctx.rng_state = op_ctx.rng_state + ctx.attn_bias_tensor = attn_bias_tensor + if op_ctx.op_bw is not None: + if op_bw is not None and op_bw is not op_ctx.op_bw: + raise ValueError( + f"Specified op_bw={op_bw.NAME}, but forward op " + f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None." 
+ ) + op_bw = op_ctx.op_bw + ctx.op_fw = op_fw + ctx.op_bw = op_bw + ctx.p = inp.p + + ctx.scale = inp.scale + ctx.attn_bias_ctx = attn_bias_ctx + return out, op_ctx.lse + + @staticmethod + def deserialize_bias( + attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor] + ) -> Any: + if attn_bias_tensor is None: + return attn_bias_ctx + return attn_bias_tensor + + @classmethod + @torch.autograd.function.once_differentiable + def backward(cls, ctx, grad, dlse): + # Re-create context + query, key, value, out, lse = ctx.saved_tensors + attn_bias_tensor = ctx.attn_bias_tensor + rng_state = ctx.rng_state + inp = Inputs( + query=query, + key=key, + value=value, + attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor), + p=ctx.p, + scale=ctx.scale, + ) + op_ctx = Context( + lse=lse, + out=out, + rng_state=rng_state, + ) + grads = _memory_efficient_attention_backward( + ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw + ) + return grads.dq, grads.dk, grads.dv, None, grads.db, None, None + + flash_attn_func = FlashAttnFunc.apply + except ModuleNotFoundError: + flash_attn_func = None +else: + flash_attn_func = None diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py index 33e917e7..6017bc8d 100644 --- a/torchscale/component/multihead_attention.py +++ b/torchscale/component/multihead_attention.py @@ -6,6 +6,7 @@ import torch import torch.nn.functional as F from torch import nn +from einops import rearrange try: from apex.normalization import FusedLayerNorm as LayerNorm except ModuleNotFoundError: @@ -13,6 +14,7 @@ from .multiway_network import MultiwayWrapper from .xpos_relative_position import XPOS +from .flash_attention import flash_attn_func class MultiheadAttention(nn.Module): @@ -32,6 +34,7 @@ def __init__( self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.scaling = self.head_dim**-0.5 + self.dropout = dropout self.self_attention = self_attention self.encoder_decoder_attention = encoder_decoder_attention @@ -62,6 +65,47 @@ def reset_parameters(self): nn.init.xavier_uniform_(self.out_proj.weight) nn.init.constant_(self.out_proj.bias, 0.0) + def attention_ops(self, q, k, v, key_padding_mask=None, attn_mask=None, rel_pos=None, is_causal=False): + if not self.args.flash_attention: + q *= self.scaling + attn_weights = torch.bmm(q, k.transpose(1, 2)) + + if attn_mask is not None: + attn_weights = torch.nan_to_num(attn_weights) + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + attn_weights = rearrange(attn_weights, '(b h) t s -> b h t s', h=self.num_heads) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + attn_weights = rearrange(attn_weights, 'b h t s -> (b h) t s') + + if rel_pos is not None: + rel_pos = rel_pos.view(attn_weights.size()) + attn_weights = attn_weights + rel_pos + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as( + attn_weights + ) + attn_probs = self.dropout_module(attn_weights) + + attn = torch.bmm(attn_probs, v) + attn = rearrange(attn, '(b h) l d -> b l (h d)', h=self.num_heads) + else: + assert flash_attn_func is not None + assert rel_pos is None + q = rearrange(q, '(b h) l d -> b l h d', h=self.num_heads) + k = rearrange(k, '(b h) l d -> b l h d', h=self.num_heads) + v = rearrange(v, '(b h) l d -> b l h d', h=self.num_heads) + attn, lse = flash_attn_func(q, k, v, self.dropout, attn_mask, None, is_causal) + attn = rearrange(attn, 'b l h d -> b l 
(h d)') + attn_weights = lse[:, :, :attn.size(1)] + + return attn, attn_weights + def forward( self, query, @@ -72,6 +116,7 @@ def forward( attn_mask=None, rel_pos=None, is_first_step=False, + is_causal=False, ): bsz, tgt_len, embed_dim = query.size() src_len = tgt_len @@ -85,14 +130,10 @@ def forward( q = self.q_proj(query) k = self.k_proj(key) v = self.v_proj(value) - q *= self.scaling - q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) - k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2) - v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2) - q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim) - k = k.reshape(bsz * self.num_heads, src_len, self.head_dim) - v = v.reshape(bsz * self.num_heads, src_len, self.head_dim) + q = rearrange(q, 'b l (h d) -> (b h) l d', h=self.num_heads) + k = rearrange(k, 'b l (h d) -> (b h) l d', h=self.num_heads) + v = rearrange(v, 'b l (h d) -> (b h) l d', h=self.num_heads) if incremental_state is not None: if "prev_key" in incremental_state: @@ -120,39 +161,11 @@ def forward( k = self.xpos(k, offset=0, downscale=True) q = self.xpos(q, offset=offset, downscale=False) - attn_weights = torch.bmm(q, k.transpose(1, 2)) - - if attn_mask is not None: - attn_weights = torch.nan_to_num(attn_weights) - attn_mask = attn_mask.unsqueeze(0) - attn_weights += attn_mask - - if key_padding_mask is not None: - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), - float("-inf"), - ) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if rel_pos is not None: - rel_pos = rel_pos.view(attn_weights.size()) - attn_weights = attn_weights + rel_pos - - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as( - attn_weights - ) - attn_probs = self.dropout_module(attn_weights) - - attn = torch.bmm(attn_probs, v) - attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1) + attn, attn_weights = self.attention_ops(q, k, v, key_padding_mask=key_padding_mask, attn_mask=attn_mask, rel_pos=rel_pos, is_causal=is_causal) if self.inner_attn_ln is not None: attn = self.inner_attn_ln(attn) attn = self.out_proj(attn) - attn_weights = attn_weights.view( - bsz, self.num_heads, tgt_len, src_len - ).transpose(1, 0) return attn, attn_weights diff --git a/torchscale/component/utils.py b/torchscale/component/utils.py new file mode 100644 index 00000000..4c8a5ad4 --- /dev/null +++ b/torchscale/component/utils.py @@ -0,0 +1,70 @@ +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import torch +import torch.distributed as dist + +def padding_to_multiple_of(n, mult): + remainder = n % mult + if remainder == 0: + return 0 + return mult - remainder + +def get_data_parallel_group(): + if torch.distributed.is_initialized(): + if not hasattr(get_data_parallel_group, "_global_group"): + get_data_parallel_group._global_group = dist.new_group() + return get_data_parallel_group._global_group + else: + return None + +def get_rank(group): + return dist.get_rank(group=group) + +def get_world_size(group): + if torch.distributed.is_initialized(): + return dist.get_world_size(group=group) + else: + return 1 + +def get_data_parallel_rank(): + return get_rank(get_data_parallel_group()) + +def get_data_parallel_world_size(): + return get_world_size(get_data_parallel_group()) + + +class Allgather(torch.autograd.Function): + + 
@staticmethod + def forward(ctx, input_): + world_size = get_data_parallel_world_size() + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, dtype=input_.dtype, + device=torch.cuda.current_device()) + torch.distributed._all_gather_base(output, input_.contiguous(), + group=get_data_parallel_group()) + + return output + + @staticmethod + def backward(ctx, grad_output): + world_size = get_data_parallel_world_size() + + dim_size = list(grad_output.size()) + assert dim_size[0] % world_size == 0, \ + "First dimension of the tensor should be divisible by tensor parallel size" + + dim_size[0] = dim_size[0] // world_size + + output = torch.empty(dim_size, dtype=grad_output.dtype, + device=torch.cuda.current_device()) + + torch.distributed._reduce_scatter_base(output, grad_output.contiguous(), + group=get_data_parallel_group()) + + return output + +all_gather_func = Allgather.apply diff --git a/torchscale/component/xmoe/global_groups.py b/torchscale/component/xmoe/global_groups.py index 3ee57520..c6c31096 100644 --- a/torchscale/component/xmoe/global_groups.py +++ b/torchscale/component/xmoe/global_groups.py @@ -59,7 +59,3 @@ def get_all2all_group(moe_expert_count): my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx) return get_all2all_group._all2all_groups[my_group_idx] - - - - diff --git a/torchscale/model/LongNet.py b/torchscale/model/LongNet.py new file mode 100644 index 00000000..f6de261d --- /dev/null +++ b/torchscale/model/LongNet.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +from torchscale.architecture.decoder import Decoder, DecoderLayer +from torchscale.architecture.encoder import Encoder, EncoderLayer +from torchscale.component.dilated_attention import DilatedAttention +from fairscale.nn import checkpoint_wrapper, wrap + + +class LongNetDecoderLayer(DecoderLayer): + + def build_self_attention(self, embed_dim, args): + return DilatedAttention( + args, + embed_dim, + args.decoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + encoder_decoder_attention=False, + subln=args.subln, + ) + +class LongNetDecoder(Decoder): + + def build_decoder_layer( + self, args, depth, is_moe_layer=False, is_encoder_decoder=False + ): + layer = LongNetDecoderLayer( + args, + depth, + is_moe_layer=is_moe_layer, + is_encoder_decoder=is_encoder_decoder, + ) + if args.checkpoint_activations: + layer = checkpoint_wrapper(layer) + if args.fsdp: + layer = wrap(layer) + return layer + +class LongNetEncoderLayer(EncoderLayer): + + def build_self_attention(self, embed_dim, args): + return DilatedAttention( + args, + embed_dim, + args.encoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + encoder_decoder_attention=False, + subln=args.subln, + ) + +class LongNetEncoder(Encoder): + + def build_encoder_layer( + self, args, depth, is_moe_layer=False, is_encoder_decoder=False + ): + layer = LongNetEncoderLayer( + args, + depth, + is_moe_layer=is_moe_layer, + is_encoder_decoder=is_encoder_decoder, + ) + if args.checkpoint_activations: + layer = checkpoint_wrapper(layer) + if args.fsdp: + layer = wrap(layer) + return layer