## requirements
### mindspore==2.3.1
### mindnlp==0.4.1

导入所需库

In [1]:
import time
from tqdm import tqdm

import mindspore
import mindspore.numpy as np
from mindspore.dataset import GeneratorDataset
from mindspore import save_checkpoint

from mindnlp.transformers import AutoProcessor, BlipForConditionalGeneration
from mindnlp.core.optim import Adam
from mindnlp.core import value_and_grad

from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm
Building prefix dict from the default dictionary ...
Dumping model to file cache /tmp/jieba.cache
Loading model cost 1.383 seconds.
Prefix dict has been built successfully.


数据集加载

In [2]:
class ImageCaptioningDataset:
    """Random-access view over a captioning dataset run through a processor.

    Every lookup fetches one record from ``dataset`` (read via its
    ``'image'`` and ``'text'`` keys) and feeds it to ``processor``; the
    processor output is returned as three arrays.
    """

    def __init__(self, dataset, processor):
        self.processor = processor
        self.dataset = dataset

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(pixel_values, input_ids, attention_mask)`` for one record.

        ``idx`` is coerced with ``int()`` whenever it is not already an
        ``int`` instance before it is used to index the wrapped dataset.
        """
        index = idx if isinstance(idx, int) else int(idx)
        sample = self.dataset[index]
        encoded = self.processor(
            images=sample['image'],
            text=sample['text'],
            padding="max_length",
        )
        # Order matters: it must line up with the column names the loader declares.
        wanted = ("pixel_values", "input_ids", "attention_mask")
        return tuple(np.asarray(encoded[key]) for key in wanted)

def get_loader(dataset, processor, batch_size, shuffle=True, num_workers=1, drop_remainder=True):
    """Build a batched ``GeneratorDataset`` over ``dataset``.

    Args:
        dataset: Raw dataset to wrap in ``ImageCaptioningDataset``.
        processor: Processor handed to ``ImageCaptioningDataset``.
        batch_size: Number of samples per batch.
        shuffle: Forwarded to ``GeneratorDataset(shuffle=...)``.
        num_workers: Forwarded as ``num_parallel_workers``.
        drop_remainder: Forwarded to ``.batch(drop_remainder=...)``.

    Returns:
        The result of ``GeneratorDataset(...).batch(...)`` with columns
        ``pixel_values``, ``input_ids`` and ``attention_mask``.
    """
    source = ImageCaptioningDataset(dataset, processor)
    columns = ["pixel_values", "input_ids", "attention_mask"]
    pipeline = GeneratorDataset(
        source=source,
        column_names=columns,
        shuffle=shuffle,
        num_parallel_workers=num_workers,
    )
    return pipeline.batch(batch_size=batch_size, drop_remainder=drop_remainder)

自定义Trainer类

In [None]:
class Trainer:
    """Minimal train/validate loop around a model and an optimizer.

    Args:
        net: Model; must provide ``trainable_params()``, ``set_train()`` and
            be callable with ``input_ids``, ``pixel_values``,
            ``attention_mask`` and ``labels`` keyword arguments, returning an
            object with a ``loss`` attribute.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        args: Namespace; ``args.save_path`` is read when validation improves.
        train_dataset: Dataset providing ``create_dict_iterator()``.
        eval_dataset: Optional dataset with the same interface. When given,
            validation runs after every training epoch.
    """

    def __init__(self, net, optimizer, args,
                 train_dataset, eval_dataset=None
                 ):
        self.net = net
        self.opt = optimizer
        self.args = args
        self.train_dataset = train_dataset
        self.weights = self.net.trainable_params()
        self.value_and_grad = value_and_grad(fn=self.forward_fn, params_or_argnums=self.weights)
        # Always define the attribute so it exists even when no eval set is given.
        self.eval_dataset = eval_dataset
        self.run_eval = eval_dataset is not None

    @staticmethod
    def _mean_loss(total, steps):
        """Return ``total / steps``, or NaN when no batch was processed.

        An empty iterator (for example a dataset smaller than the batch size
        combined with ``drop_remainder=True``) would otherwise raise
        ``ZeroDivisionError``. NaN compares false against any best loss, so
        it can never be mistaken for an improvement.
        """
        if steps == 0:
            return float('nan')
        return total / steps

    @staticmethod
    def _unpack_batch(batch):
        """Return ``(input_ids, pixel_values, attention_mask)`` from a batch dict.

        ``pixel_values`` has its axis 1 squeezed out.
        NOTE(review): presumably the processor adds a size-1 leading axis per
        sample, which becomes axis 1 after batching; confirm.
        """
        input_ids = batch["input_ids"]
        pixel_values = batch["pixel_values"].squeeze(1)
        attention_mask = batch["attention_mask"]
        return input_ids, pixel_values, attention_mask

    def forward_fn(self, input_ids, pixel_values, attention_mask):
        """Run the model with ``input_ids`` also used as labels; return the loss."""
        outputs = self.net(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, labels=input_ids)
        return outputs.loss

    def train_single(self, input_ids, pixel_values, attention_mask):
        """Perform one optimisation step on a single batch and return its loss.

        NOTE(review): only the value returned by ``self.value_and_grad`` is
        used here; gradients are presumably attached to the parameters for
        ``self.opt.step()`` to consume. Confirm against mindnlp.core docs.
        """
        self.opt.zero_grad()
        loss = self.value_and_grad(input_ids, pixel_values, attention_mask)
        self.opt.step()
        return loss

    def train(self, epochs):
        """Train for ``epochs`` epochs, validating after each one if enabled.

        When validation loss improves and ``args.save_path`` is not ``None``
        the model is checkpointed.
        """
        best_val_loss = float('inf')
        for epoch in range(epochs):
            print(f"\nEpoch {epoch + 1}/{epochs}")
            self.net.set_train(True)
            tloss = 0
            step = 0
            for batch in tqdm(self.train_dataset.create_dict_iterator()):
                loss = self.train_single(*self._unpack_batch(batch))
                tloss = tloss + loss.asnumpy()
                step = step + 1

            if step == 0:
                print("\tWarning: training dataset produced no batches")
            tloss = self._mean_loss(tloss, step)
            print(f"\tTrain Loss {tloss:.04f}")

            if not self.run_eval:
                continue

            self.net.set_train(False)
            val_loss = self.val()
            print(f"Epoch {epoch + 1} complete! Validation Loss : {val_loss}")
            if val_loss < best_val_loss:
                print(f"Best validation Loss improved from {best_val_loss} to {val_loss}")
                best_val_loss = val_loss
                if self.args.save_path is not None:
                    print("saving model...")
                    # NOTE: save_path is used as a raw string prefix, so it must
                    # end with a path separator to place the file in a directory.
                    save_checkpoint(self.net, self.args.save_path + 'best_model.ckpt')

    def val(self):
        """Return the mean loss over ``self.eval_dataset`` (NaN if it is empty)."""
        vloss = 0
        step = 0
        with mindspore._no_grad():
            for batch in tqdm(self.eval_dataset.create_dict_iterator()):
                # Same forward pass as training, without the optimiser step.
                loss = self.forward_fn(*self._unpack_batch(batch))
                vloss = vloss + loss.asnumpy()
                step = step + 1

        return self._mean_loss(vloss, step)

主函数入口，完整训练流程

In [4]:
def main(args):
    """Run the whole fine-tuning pipeline: model, optimizer, loaders, training.

    ``args`` must expose ``device_target``, ``device_id``,
    ``model_name_or_path``, ``dataset_name_or_path``, ``batch_size`` and
    ``max_eps``; it is also handed to ``Trainer`` unchanged.
    """
    def report_elapsed(since):
        # Shared timing footer printed after every stage.
        print("Done in {} seconds".format(time.time() - since))

    # Stage 1: configure the device, then load the BLIP processor and model.
    print("Building model! (This might take time if you are running this for first time)")
    started = time.time()
    mindspore.set_context(
        device_target=args.device_target,
        device_id=args.device_id,
        pynative_synchronize=True,
    )
    processor = AutoProcessor.from_pretrained(args.model_name_or_path)
    model = BlipForConditionalGeneration.from_pretrained(args.model_name_or_path)
    report_elapsed(started)

    # Stage 2: optimizer over the model's trainable parameters.
    print("Creating optimizer objects")
    started = time.time()
    optimizer = Adam(model.trainable_params(), lr=5e-5)
    report_elapsed(started)

    # Stage 3: data loaders built from the 'train' and 'test' splits.
    print("Creating train and val dataloaders")
    started = time.time()
    splits = load_dataset(args.dataset_name_or_path)
    train_loader = get_loader(splits['train'], processor, args.batch_size, shuffle=True, drop_remainder=True)
    val_loader = get_loader(splits['test'], processor, args.batch_size, shuffle=True, drop_remainder=False)
    report_elapsed(started)

    # Stage 4: training with per-epoch validation.
    print("Let the training begin")
    started = time.time()
    runner = Trainer(
        net=model,
        optimizer=optimizer,
        args=args,
        train_dataset=train_loader,
        eval_dataset=val_loader,
    )
    runner.train(epochs=args.max_eps)
    report_elapsed(started)

设置训练参数，开始训练

In [5]:
from types import SimpleNamespace

# Run configuration consumed by main(); save_path=None disables checkpointing.
args = SimpleNamespace(
    device_target='Ascend',
    device_id=0,
    model_name_or_path='Salesforce/blip-image-captioning-base',
    dataset_name_or_path='eeshclusive/captionary-dataset',
    batch_size=4,
    max_eps=20,
    save_path=None,
)

main(args)

Building model! (This might take time if you are running this for first time)




[MS_ALLOC_CONF]Runtime config:  enable_vmm:True  vmm_align_size:2MB


BlipTextLMHeadModel has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`.`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


Done in 17.581424474716187 seconds
Creating optimizer objects
Done in 0.0065310001373291016 seconds
Creating train and val dataloaders


Generating train split: 100%|██████████| 162/162 [00:00<00:00, 440.00 examples/s]
Generating test split: 100%|██████████| 51/51 [00:00<00:00, 728.11 examples/s]


Done in 15.54231882095337 seconds
Let the training begin

Epoch 1/20


0it [00:00, ?it/s]

-

40it [01:41,  2.54s/it]


	Train Loss 7.3443


13it [00:04,  2.77it/s]


Epoch 1 complete! Validation Loss : 4.915086085979755
Best validation Loss improved from inf to 4.915086085979755

Epoch 2/20


40it [00:49,  1.24s/it]


	Train Loss 3.2319


13it [00:04,  2.89it/s]


Epoch 2 complete! Validation Loss : 1.8268253069657545
Best validation Loss improved from 4.915086085979755 to 1.8268253069657545

Epoch 3/20


40it [00:48,  1.22s/it]


	Train Loss 1.1534


13it [00:04,  2.81it/s]


Epoch 3 complete! Validation Loss : 0.5436725112108084
Best validation Loss improved from 1.8268253069657545 to 0.5436725112108084

Epoch 4/20


40it [00:48,  1.21s/it]


	Train Loss 0.3363


13it [00:04,  2.92it/s]


Epoch 4 complete! Validation Loss : 0.20180132755866417
Best validation Loss improved from 0.5436725112108084 to 0.20180132755866417

Epoch 5/20


40it [00:52,  1.31s/it]


	Train Loss 0.1522


13it [00:04,  2.79it/s]


Epoch 5 complete! Validation Loss : 0.1140028633750402
Best validation Loss improved from 0.20180132755866417 to 0.1140028633750402

Epoch 6/20


40it [00:50,  1.27s/it]


	Train Loss 0.0940


13it [00:04,  2.75it/s]


Epoch 6 complete! Validation Loss : 0.07747195661067963
Best validation Loss improved from 0.1140028633750402 to 0.07747195661067963

Epoch 7/20


40it [00:50,  1.26s/it]


	Train Loss 0.0668


13it [00:04,  2.74it/s]


Epoch 7 complete! Validation Loss : 0.05752776018702067
Best validation Loss improved from 0.07747195661067963 to 0.05752776018702067

Epoch 8/20


40it [00:51,  1.29s/it]


	Train Loss 0.0514


13it [00:04,  2.92it/s]


Epoch 8 complete! Validation Loss : 0.045433574284498505
Best validation Loss improved from 0.05752776018702067 to 0.045433574284498505

Epoch 9/20


40it [00:50,  1.27s/it]


	Train Loss 0.0413


13it [00:04,  2.77it/s]


Epoch 9 complete! Validation Loss : 0.03752241713496355
Best validation Loss improved from 0.045433574284498505 to 0.03752241713496355

Epoch 10/20


40it [00:50,  1.25s/it]


	Train Loss 0.0345


13it [00:04,  2.94it/s]


Epoch 10 complete! Validation Loss : 0.03150226190113104
Best validation Loss improved from 0.03752241713496355 to 0.03150226190113104

Epoch 11/20


40it [00:49,  1.24s/it]


	Train Loss 0.0294


13it [00:04,  2.91it/s]


Epoch 11 complete! Validation Loss : 0.027369202186281864
Best validation Loss improved from 0.03150226190113104 to 0.027369202186281864

Epoch 12/20


40it [00:49,  1.23s/it]


	Train Loss 0.0258


13it [00:04,  2.65it/s]


Epoch 12 complete! Validation Loss : 0.024082990936361827
Best validation Loss improved from 0.027369202186281864 to 0.024082990936361827

Epoch 13/20


40it [00:48,  1.21s/it]


	Train Loss 0.0230


13it [00:04,  2.76it/s]


Epoch 13 complete! Validation Loss : 0.021563996345951006
Best validation Loss improved from 0.024082990936361827 to 0.021563996345951006

Epoch 14/20


40it [00:50,  1.26s/it]


	Train Loss 0.0206


13it [00:04,  2.79it/s]


Epoch 14 complete! Validation Loss : 0.019490097291194476
Best validation Loss improved from 0.021563996345951006 to 0.019490097291194476

Epoch 15/20


40it [00:50,  1.26s/it]


	Train Loss 0.0188


13it [00:04,  2.95it/s]


Epoch 15 complete! Validation Loss : 0.018077760504988525
Best validation Loss improved from 0.019490097291194476 to 0.018077760504988525

Epoch 16/20


40it [00:48,  1.22s/it]


	Train Loss 0.0172


13it [00:04,  2.78it/s]


Epoch 16 complete! Validation Loss : 0.01667449616182309
Best validation Loss improved from 0.018077760504988525 to 0.01667449616182309

Epoch 17/20


40it [00:48,  1.21s/it]


	Train Loss 0.0160


13it [00:04,  2.77it/s]


Epoch 17 complete! Validation Loss : 0.015317266162198324
Best validation Loss improved from 0.01667449616182309 to 0.015317266162198324

Epoch 18/20


40it [00:48,  1.21s/it]


	Train Loss 0.0149


13it [00:04,  2.72it/s]


Epoch 18 complete! Validation Loss : 0.014371497556567192
Best validation Loss improved from 0.015317266162198324 to 0.014371497556567192

Epoch 19/20


40it [00:49,  1.24s/it]


	Train Loss 0.0139


13it [00:04,  2.84it/s]


Epoch 19 complete! Validation Loss : 0.013473815069748806
Best validation Loss improved from 0.014371497556567192 to 0.013473815069748806

Epoch 20/20


40it [00:47,  1.19s/it]


	Train Loss 0.0132


13it [00:04,  2.86it/s]

Epoch 20 complete! Validation Loss : 0.012598874477239756
Best validation Loss improved from 0.013473815069748806 to 0.012598874477239756
Done in 1139.0716316699982 seconds



