From 78ab058fd2017a5371854b9781c01aec92e796fc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 25 Sep 2025 18:11:05 +0100 Subject: [PATCH 01/27] Let's just commit 3k loc in a single commit --- .../_core/BetaCudaDeviceInterface.cpp | 563 +++++++ .../_core/BetaCudaDeviceInterface.h | 108 ++ src/torchcodec/_core/CMakeLists.txt | 25 +- src/torchcodec/_core/CudaDeviceInterface.cpp | 77 +- src/torchcodec/_core/DeviceInterface.cpp | 70 +- src/torchcodec/_core/DeviceInterface.h | 89 +- src/torchcodec/_core/FFMPEGCommon.cpp | 8 + src/torchcodec/_core/FFMPEGCommon.h | 4 + src/torchcodec/_core/NVDECCache.cpp | 70 + src/torchcodec/_core/NVDECCache.h | 102 ++ src/torchcodec/_core/SingleStreamDecoder.cpp | 62 +- src/torchcodec/_core/SingleStreamDecoder.h | 1 + src/torchcodec/_core/StreamOptions.h | 3 + src/torchcodec/_core/custom_ops.cpp | 20 +- .../_core/nvcuvid_include/cuviddec.h | 1374 +++++++++++++++++ .../_core/nvcuvid_include/nvcuvid.h | 610 ++++++++ src/torchcodec/_core/ops.py | 2 + src/torchcodec/decoders/_video_decoder.py | 19 + test/test_decoders.py | 59 + test/utils.py | 10 + 20 files changed, 3204 insertions(+), 72 deletions(-) create mode 100644 src/torchcodec/_core/BetaCudaDeviceInterface.cpp create mode 100644 src/torchcodec/_core/BetaCudaDeviceInterface.h create mode 100644 src/torchcodec/_core/NVDECCache.cpp create mode 100644 src/torchcodec/_core/NVDECCache.h create mode 100644 src/torchcodec/_core/nvcuvid_include/cuviddec.h create mode 100644 src/torchcodec/_core/nvcuvid_include/nvcuvid.h diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp new file mode 100644 index 000000000..82b8d56b4 --- /dev/null +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -0,0 +1,563 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include "src/torchcodec/_core/BetaCudaDeviceInterface.h" + +#include "src/torchcodec/_core/DeviceInterface.h" +#include "src/torchcodec/_core/FFMPEGCommon.h" +#include "src/torchcodec/_core/NVDECCache.h" + +#include // For cudaStreamSynchronize +#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h" +#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h" + +extern "C" { +#include +#include +} + +namespace facebook::torchcodec { + +namespace { + +// Register the BETA CUDA interface with 'beta' variant +static bool g_cuda_beta = registerDeviceInterface( + DeviceInterfaceKey(torch::kCUDA, "beta"), + [](const torch::Device& device) { + return new BetaCudaDeviceInterface(device); + }); + +static int CUDAAPI +pfnSequenceCallback(void* pUserData, CUVIDEOFORMAT* videoFormat) { + BetaCudaDeviceInterface* decoder = + static_cast(pUserData); + return static_cast(decoder->streamPropertyChange(videoFormat)); +} + +static int CUDAAPI +pfnDecodePictureCallback(void* pUserData, CUVIDPICPARAMS* pPicParams) { + BetaCudaDeviceInterface* decoder = + static_cast(pUserData); + return decoder->frameReadyForDecoding(pPicParams); +} + +static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) { + // Check decoder capabilities - same checks as DALI + auto caps = CUVIDDECODECAPS{}; + caps.eCodecType = videoFormat->codec; + caps.eChromaFormat = videoFormat->chroma_format; + caps.nBitDepthMinus8 = videoFormat->bit_depth_luma_minus8; + CUresult result = cuvidGetDecoderCaps(&caps); + TORCH_CHECK(result == CUDA_SUCCESS, "Failed to get decoder caps: ", result); + + TORCH_CHECK( + caps.bIsSupported, + "Codec configuration not supported on this GPU. " + "Codec: ", + static_cast(videoFormat->codec), + ", chroma format: ", + static_cast(videoFormat->chroma_format), + ", bit depth: ", + videoFormat->bit_depth_luma_minus8 + 8); + + TORCH_CHECK( + videoFormat->coded_width >= caps.nMinWidth && + videoFormat->coded_height >= caps.nMinHeight, + "Video is too small in at least one dimension. Provided: ", + videoFormat->coded_width, + "x", + videoFormat->coded_height, + " vs supported:", + caps.nMinWidth, + "x", + caps.nMinHeight); + + TORCH_CHECK( + videoFormat->coded_width <= caps.nMaxWidth && + videoFormat->coded_height <= caps.nMaxHeight, + "Video is too large in at least one dimension. Provided: ", + videoFormat->coded_width, + "x", + videoFormat->coded_height, + " vs supported:", + caps.nMaxWidth, + "x", + caps.nMaxHeight); + + TORCH_CHECK( + videoFormat->coded_width * videoFormat->coded_height / 256 <= + caps.nMaxMBCount, + "Video is too large (too many macroblocks). " + "Provided (width * height / 256): ", + videoFormat->coded_width * videoFormat->coded_height / 256, + " vs supported:", + caps.nMaxMBCount); + + // Decoder creation parameters, taken from DALI + CUVIDDECODECREATEINFO decoder_info = {}; + decoder_info.bitDepthMinus8 = videoFormat->bit_depth_luma_minus8; + decoder_info.ChromaFormat = videoFormat->chroma_format; + decoder_info.CodecType = videoFormat->codec; + decoder_info.ulHeight = videoFormat->coded_height; + decoder_info.ulWidth = videoFormat->coded_width; + decoder_info.ulMaxHeight = videoFormat->coded_height; + decoder_info.ulMaxWidth = videoFormat->coded_width; + decoder_info.ulTargetHeight = + videoFormat->display_area.bottom - videoFormat->display_area.top; + decoder_info.ulTargetWidth = + videoFormat->display_area.right - videoFormat->display_area.left; + decoder_info.ulNumDecodeSurfaces = videoFormat->min_num_decode_surfaces; + decoder_info.ulNumOutputSurfaces = 2; + decoder_info.display_area.left = videoFormat->display_area.left; + decoder_info.display_area.right = videoFormat->display_area.right; + decoder_info.display_area.top = videoFormat->display_area.top; + decoder_info.display_area.bottom = videoFormat->display_area.bottom; + + CUvideodecoder rawDecoder; + result = cuvidCreateDecoder(&rawDecoder, &decoder_info); + TORCH_CHECK( + result == CUDA_SUCCESS, "Failed to create NVDEC decoder: ", result); + + return UniqueCUvideodecoder(rawDecoder, CUvideoDecoderDeleter{}); +} + +} // namespace + +BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device) + : DeviceInterface(device) { + TORCH_CHECK(g_cuda_beta, "BetaCudaDeviceInterface was not registered!"); + TORCH_CHECK( + device_.type() == torch::kCUDA, "Unsupported device: ", device_.str()); + + // TODONVDEC P1: init size should probably be min_num_decode_surfaces from + // video format + frameBuffer_.resize(4); +} + +BetaCudaDeviceInterface::~BetaCudaDeviceInterface() { + // TODONVDEC P0: we probably need to free the frames that have been decoded by + // NVDEC but not yet "mapped" - i.e. those that are still in frameBuffer_? + + if (decoder_) { + NVDECCache::GetCache(device_.index()) + .returnDecoder(&videoFormat_, std::move(decoder_)); + } + + if (videoParser_) { + // TODONVDEC P2: consider caching this? Does DALI do that? + cuvidDestroyVideoParser(videoParser_); + videoParser_ = nullptr; + } +} + +void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) { + TORCH_CHECK(avStream != nullptr, "AVStream cannot be null"); + timeBase_ = avStream->time_base; + + const AVCodecParameters* codecpar = avStream->codecpar; + TORCH_CHECK(codecpar != nullptr, "CodecParameters cannot be null"); + + TORCH_CHECK( + // TODONVDEC P0 support more + avStream->codecpar->codec_id == AV_CODEC_ID_H264, + "Can only do H264 for now"); + + // Setup bit stream filters (BSF): + // https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html + // This is only needed for some formats, like H264 or HEVC. TODONVDEC P1: For + // now we apply BSF unconditionally, but it should be optional and dependent + // on codec and container. + const AVBitStreamFilter* avBSF = av_bsf_get_by_name("h264_mp4toannexb"); + TORCH_CHECK( + avBSF != nullptr, "Failed to find h264_mp4toannexb bitstream filter"); + + AVBSFContext* avBSFContext = nullptr; + int retVal = av_bsf_alloc(avBSF, &avBSFContext); + TORCH_CHECK( + retVal >= AVSUCCESS, + "Failed to allocate bitstream filter: ", + getFFMPEGErrorStringFromErrorCode(retVal)); + + bitstreamFilter_.reset(avBSFContext); + + retVal = avcodec_parameters_copy(bitstreamFilter_->par_in, codecpar); + TORCH_CHECK( + retVal >= AVSUCCESS, + "Failed to copy codec parameters: ", + getFFMPEGErrorStringFromErrorCode(retVal)); + + retVal = av_bsf_init(bitstreamFilter_.get()); + TORCH_CHECK( + retVal == AVSUCCESS, + "Failed to initialize bitstream filter: ", + getFFMPEGErrorStringFromErrorCode(retVal)); + + // Create parser. Default values that aren't obvious are taken from DALI. + CUVIDPARSERPARAMS parserParams = {}; + parserParams.CodecType = cudaVideoCodec_H264; + parserParams.ulMaxNumDecodeSurfaces = 8; + parserParams.ulMaxDisplayDelay = 0; + // Callback setup, all are triggered by the parser within a call + // to cuvidParseVideoData + parserParams.pUserData = this; + parserParams.pfnSequenceCallback = pfnSequenceCallback; + parserParams.pfnDecodePicture = pfnDecodePictureCallback; + parserParams.pfnDisplayPicture = nullptr; + + CUresult result = cuvidCreateVideoParser(&videoParser_, &parserParams); + TORCH_CHECK( + result == CUDA_SUCCESS, "Failed to create video parser: ", result); +} + +// This callback is called by the parser within cuvidParseVideoData when there +// is a change in the stream's properties (like resolution change), as specified +// by CUVIDEOFORMAT. Particularly (but not just!), this is called at the very +// start of the stream. +// TODONVDEC P1: Code below mostly assume this is called only once at the start, +// we should handle the case of multiple calls. Probably need to flush buffers, +// etc. +unsigned char BetaCudaDeviceInterface::streamPropertyChange( + CUVIDEOFORMAT* videoFormat) { + TORCH_CHECK(videoFormat != nullptr, "Invalid video format"); + + videoFormat_ = *videoFormat; + + if (videoFormat_.min_num_decode_surfaces == 0) { + // Same as DALI's fallback + videoFormat_.min_num_decode_surfaces = 20; + } + + if (!decoder_) { + decoder_ = NVDECCache::GetCache(device_.index()).getDecoder(videoFormat); + + if (!decoder_) { + // TODONVDEC P0: consider re-configuring an existing decoder instead of + // re-creating one. See docs, see DALI. + decoder_ = createDecoder(videoFormat); + } + + TORCH_CHECK(decoder_, "Failed to get or create decoder"); + } + + // DALI also returns min_num_decode_surfaces from this function. This + // instructs the parser to reset its ulMaxNumDecodeSurfaces field to this + // value. + return videoFormat_.min_num_decode_surfaces; +} + +// Moral equivalent of avcodec_send_packet(). Here, we pass the AVPacket down to +// the NVCUVID parser. +int BetaCudaDeviceInterface::sendPacket(ReferenceAVPacket& packet) { + CUVIDSOURCEDATAPACKET cuvidPacket = {}; + + if (packet.get() && packet->data && packet->size > 0) { + // Regular packet with data + cuvidPacket.payload = packet->data; + cuvidPacket.payload_size = packet->size; + cuvidPacket.flags = CUVID_PKT_TIMESTAMP; + cuvidPacket.timestamp = packet->pts; + + // Like DALI: store packet PTS in queue to later assign to frames as they + // come out + packetsPtsQueue.push(packet->pts); + + } else { + // End of stream packet + cuvidPacket.flags = CUVID_PKT_ENDOFSTREAM; + eofSent_ = true; + } + + CUresult result = cuvidParseVideoData(videoParser_, &cuvidPacket); + if (result != CUDA_SUCCESS) { + return AVERROR_EXTERNAL; + } + return AVSUCCESS; +} + +// TODONVDEC P0: cleanup this raw pointer / reference monstruosity. +ReferenceAVPacket* BetaCudaDeviceInterface::applyBSF( + ReferenceAVPacket& packet, + [[maybe_unused]] AutoAVPacket& filteredAutoPacket, + ReferenceAVPacket& filteredPacket) { + if (!bitstreamFilter_) { + return &packet; + } + int retVal = av_bsf_send_packet(bitstreamFilter_.get(), packet.get()); + TORCH_CHECK( + retVal >= AVSUCCESS, + "Failed to send packet to bitstream filter: ", + getFFMPEGErrorStringFromErrorCode(retVal)); + + retVal = av_bsf_receive_packet(bitstreamFilter_.get(), filteredPacket.get()); + TORCH_CHECK( + retVal >= AVSUCCESS, + "Failed to receive packet from bitstream filter: ", + getFFMPEGErrorStringFromErrorCode(retVal)); + + return &filteredPacket; +} + +// Parser triggers this callback within cuvidParseVideoData when a frame is +// ready to be decoded, i.e. the parser received all the necessary packets for a +// given frame. It means we can send that frame to be decoded by the hardware +// NVDEC decoder by calling cuvidDecodePicture which is non-blocking. +int BetaCudaDeviceInterface::frameReadyForDecoding(CUVIDPICPARAMS* pPicParams) { + if (isFlushing_) { + return 0; + } + + TORCH_CHECK(pPicParams != nullptr, "Invalid picture parameters"); + TORCH_CHECK(decoder_, "Decoder not initialized before picture decode"); + + // Send frame to be decoded by NVDEC - non-blocking call. + CUresult result = cuvidDecodePicture(decoder_.get(), pPicParams); + if (result != CUDA_SUCCESS) { + return 0; // Yes, you're reading that right, 0 mean error. + } + + // The frame was sent to be decoded on the NVDEC hardware. Now we store some + // relevant info into our frame buffer so that we can retrieve the decoded + // frame later when receiveFrame() is called. + // Importantly we need to 'guess' the PTS of that frame. The heuristic we use + // (like in DALI) is that the frames are ready to be decoded in the same order + // as the packets were sent to the parser. So we assign the PTS of the frame + // by popping the PTS of the oldest packet in our packetsPtsQueue (note: + // oldest doesn't necessarily mean lowest PTS!). + + TORCH_CHECK( + // TODONVDEC P0 the queue may be empty, handle that. + !packetsPtsQueue.empty(), + "PTS queue is empty when decoding a frame"); + int64_t guessedPts = packetsPtsQueue.front(); + packetsPtsQueue.pop(); + + // Field values taken from DALI + CUVIDPARSERDISPINFO dispInfo = {}; + dispInfo.picture_index = pPicParams->CurrPicIdx; + dispInfo.progressive_frame = !pPicParams->field_pic_flag; + dispInfo.top_field_first = pPicParams->bottom_field_flag ^ 1; + dispInfo.repeat_first_field = 0; + dispInfo.timestamp = guessedPts; + + FrameBufferSlot* slot = findEmptySlot(); + slot->dispInfo = dispInfo; + slot->guessedPts = guessedPts; + slot->occupied = true; + + return 1; +} + +// Moral equivalent of avcodec_receive_frame(). Here, we look for a decoded +// frame with the exact desired PTS in our frame buffer. This logic is only +// valid in exact seek_mode, for now. +int BetaCudaDeviceInterface::receiveFrame( + UniqueAVFrame& avFrame, + int64_t desiredPts) { + FrameBufferSlot* slot = findFrameWithExactPts(desiredPts); + if (slot == nullptr) { + // No frame found, instruct caller to try again later after sending more + // packets. + return AVERROR(EAGAIN); + } + + slot->occupied = false; + slot->guessedPts = -1; + + CUVIDPROCPARAMS procParams = {}; + CUVIDPARSERDISPINFO dispInfo = slot->dispInfo; + procParams.progressive_frame = dispInfo.progressive_frame; + procParams.top_field_first = dispInfo.top_field_first; + procParams.unpaired_field = dispInfo.repeat_first_field < 0; + CUdeviceptr framePtr = 0; + unsigned int pitch = 0; + + // We know the frame we want was sent to the hardware decoder, but now we need + // to "map" it to an "output surface" before we can use its data. This is a + // blocking calls that waits until the frame is fully decoded and ready to be + // used. + CUresult result = cuvidMapVideoFrame( + static_cast(decoder_.get()), + dispInfo.picture_index, + &framePtr, + &pitch, + &procParams); + + if (result != CUDA_SUCCESS) { + return AVERROR_EXTERNAL; + } + + avFrame = convertCudaFrameToAVFrame(framePtr, pitch, dispInfo); + + // Unmap the frame so that the decoder can reuse its corresponding output + // surface. Whether this is blocking is unclear? + cuvidUnmapVideoFrame(static_cast(decoder_.get()), framePtr); + // TODONVDEC P0: Get clarity on this: + // We assume that the framePtr is still valid after unmapping. That framePtr + // is now part of the avFrame, which we'll return to the caller, and the + // caller will immediately use it for color-conversion, at which point a copy + // happens. After the copy, it doesn't matter whether framePtr is still valid. + // And we'll return to this function (and to cuvidUnmapVideoFrame()) *after* + // the copy is made, so there should be no risk of overwriting the data before + // the copy. + // Buuuut yeah, we need get more clarity on what actually happens, and on + // what's needed. IIUC DALI makes the color-conversion copy immediately after + // cuvidMapVideoFrame() and *before* cuvidUnmapVideoFrame() with a synchronize + // in between. So maybe we should do the same. + + return AVSUCCESS; +} + +UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame( + CUdeviceptr framePtr, + unsigned int pitch, + const CUVIDPARSERDISPINFO& dispInfo) { + TORCH_CHECK(framePtr != 0, "Invalid CUDA frame pointer"); + + // Get frame dimensions from video format display area (not coded dimensions) + // This matches DALI's approach and avoids padding issues + int width = videoFormat_.display_area.right - videoFormat_.display_area.left; + int height = videoFormat_.display_area.bottom - videoFormat_.display_area.top; + + TORCH_CHECK(width > 0 && height > 0, "Invalid frame dimensions"); + TORCH_CHECK( + pitch >= static_cast(width), "Pitch must be >= width"); + + UniqueAVFrame avFrame(av_frame_alloc()); + TORCH_CHECK(avFrame.get() != nullptr, "Failed to allocate AVFrame"); + + avFrame->width = width; + avFrame->height = height; + avFrame->format = AV_PIX_FMT_CUDA; + avFrame->pts = dispInfo.timestamp; // == guessedPts + + unsigned int frameRateNum = videoFormat_.frame_rate.numerator; + unsigned int frameRateDen = videoFormat_.frame_rate.denominator; + int64_t duration = static_cast((frameRateDen * timeBase_.den)) / + (frameRateNum * timeBase_.num); + setDuration(avFrame, duration); + + // We need to assign the frame colorspace. This is crucial for proper color + // conversion. NVCUVID stores that in the matrix_coefficients field, but + // doesn't document the semantics of the values. Claude code generated this, + // which seems to work. Reassuringly, the values seem to match the + // corresponding indices in the FFmpeg enum for colorspace conversion + // (ff_yuv2rgb_coeffs): + // https://ffmpeg.org/doxygen/trunk/yuv2rgb_8c_source.html#l00047 + switch (videoFormat_.video_signal_description.matrix_coefficients) { + case 1: + avFrame->colorspace = AVCOL_SPC_BT709; + break; + case 6: + avFrame->colorspace = AVCOL_SPC_SMPTE170M; // BT.601 + break; + default: + // Default to BT.601 + avFrame->colorspace = AVCOL_SPC_SMPTE170M; + break; + } + + avFrame->color_range = + videoFormat_.video_signal_description.video_full_range_flag + ? AVCOL_RANGE_JPEG + : AVCOL_RANGE_MPEG; + + // Below: Ask Claude. I'm not going to even pretend. + avFrame->data[0] = reinterpret_cast(framePtr); + avFrame->data[1] = reinterpret_cast(framePtr + (pitch * height)); + avFrame->data[2] = nullptr; + avFrame->data[3] = nullptr; + avFrame->linesize[0] = pitch; + avFrame->linesize[1] = pitch; + avFrame->linesize[2] = 0; + avFrame->linesize[3] = 0; + + return avFrame; +} + +void BetaCudaDeviceInterface::flush() { + isFlushing_ = true; + + // TODONVDEC P0: simplify flushing and "eofSent_" logic. We should just have a + // "sendEofPacket()" function that does the right thing, instead of setting + // CUVID_PKT_ENDOFSTREAM in different places. + if (!eofSent_) { + CUVIDSOURCEDATAPACKET cuvidPacket = {}; + cuvidPacket.flags = CUVID_PKT_ENDOFSTREAM; + CUresult result = cuvidParseVideoData(videoParser_, &cuvidPacket); + if (result == CUDA_SUCCESS) { + eofSent_ = true; + } + } + + isFlushing_ = false; + + for (auto& slot : frameBuffer_) { + slot.occupied = false; + slot.guessedPts = -1; + } + + std::queue empty; + packetsPtsQueue.swap(empty); + + eofSent_ = false; +} + +void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( + const VideoStreamOptions& videoStreamOptions, + const AVRational& timeBase, + UniqueAVFrame& avFrame, + FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor) { + TORCH_CHECK( + avFrame->format == AV_PIX_FMT_CUDA, + "Expected CUDA format frame from BETA CUDA interface"); + + // TODONVDEC P1: we use the 'default' cuda device interface for color + // conversion. That's a temporary hack to make things work. we should abstract + // the color conversion stuff separately. + if (!defaultCudaInterface_) { + auto cudaDevice = torch::Device(torch::kCUDA); + defaultCudaInterface_ = + std::unique_ptr(createDeviceInterface(cudaDevice)); + AVCodecContext dummyCodecContext = {}; + defaultCudaInterface_->initializeContext(&dummyCodecContext); + } + + defaultCudaInterface_->convertAVFrameToFrameOutput( + videoStreamOptions, + timeBase, + avFrame, + frameOutput, + preAllocatedOutputTensor); +} + +// TODONVDEC P0: Don't let buffer grow indefinitely. +BetaCudaDeviceInterface::FrameBufferSlot* +BetaCudaDeviceInterface::findEmptySlot() { + for (auto& slot : frameBuffer_) { + if (!slot.occupied) { + return &slot; + } + } + frameBuffer_.emplace_back(); + return &frameBuffer_.back(); +} + +BetaCudaDeviceInterface::FrameBufferSlot* +BetaCudaDeviceInterface::findFrameWithExactPts(int64_t desiredPts) { + for (auto& slot : frameBuffer_) { + if (slot.occupied && slot.guessedPts == desiredPts) { + return &slot; + } + } + return nullptr; +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h new file mode 100644 index 000000000..632551e1c --- /dev/null +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -0,0 +1,108 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +// BETA CUDA device interface that provides direct control over NVDEC +// while keeping FFmpeg for demuxing. A lot of the logic, particularly the use +// of a cache for the decoders, is inspired by DALI's implementation which is +// APACHE 2.0: +// https://github.com/NVIDIA/DALI/blob/c7539676a24a8e9e99a6e8665e277363c5445259/dali/operators/video/frames_decoder_gpu.cc#L1 +// +// NVDEC / NVCUVID docs: +// https://docs.nvidia.com/video-technologies/video-codec-sdk/13.0/nvdec-video-decoder-api-prog-guide/index.html#using-nvidia-video-decoder-nvdecode-api + +#pragma once + +#include "src/torchcodec/_core/Cache.h" +#include "src/torchcodec/_core/DeviceInterface.h" +#include "src/torchcodec/_core/FFMPEGCommon.h" +#include "src/torchcodec/_core/NVDECCache.h" + +#include +#include +#include +#include +#include +#include + +#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h" +#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h" + +namespace facebook::torchcodec { + +class BetaCudaDeviceInterface : public DeviceInterface { + public: + explicit BetaCudaDeviceInterface(const torch::Device& device); + virtual ~BetaCudaDeviceInterface(); + + void initializeInterface(AVStream* stream) override; + + void convertAVFrameToFrameOutput( + const VideoStreamOptions& videoStreamOptions, + const AVRational& timeBase, + UniqueAVFrame& avFrame, + FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor = + std::nullopt) override; + + bool canDecodePacketDirectly() const override { + return true; + } + + int sendPacket(ReferenceAVPacket& packet) override; + int receiveFrame(UniqueAVFrame& avFrame, int64_t desiredPts) override; + void flush() override; + ReferenceAVPacket* applyBSF( + ReferenceAVPacket& packet, + AutoAVPacket& filteredAutoPacket, + ReferenceAVPacket& filteredPacket) override; + + // NVDEC callback functions (must be public for C callbacks) + unsigned char streamPropertyChange(CUVIDEOFORMAT* videoFormat); + int frameReadyForDecoding(CUVIDPICPARAMS* pPicParams); + + private: + UniqueAVFrame convertCudaFrameToAVFrame( + CUdeviceptr framePtr, + unsigned int pitch, + const CUVIDPARSERDISPINFO& dispInfo); + + CUvideoparser videoParser_ = nullptr; + UniqueCUvideodecoder decoder_; + CUVIDEOFORMAT videoFormat_ = {}; + + struct FrameBufferSlot { + CUVIDPARSERDISPINFO dispInfo; + int64_t guessedPts; + bool occupied = false; + + FrameBufferSlot() : guessedPts(-1), occupied(false) { + memset(&dispInfo, 0, sizeof(dispInfo)); + } + }; + + std::vector frameBuffer_; + FrameBufferSlot* findEmptySlot(); + FrameBufferSlot* findFrameWithExactPts(int64_t desiredPts); + + std::queue packetsPtsQueue; + + bool eofSent_ = false; + + // Flush flag to prevent decode operations during flush (like DALI's + // isFlushing_) + bool isFlushing_ = false; + + AVRational timeBase_ = {0, 0}; + + UniqueAVBSFContext bitstreamFilter_; + + // Default CUDA interface for color conversion. + // TODONVDEC P2: we shouldn't need to keep a separate instance of the default. + // See other TODO there about how interfaces should be completely independent. + std::unique_ptr defaultCudaInterface_; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index e3f9102e2..da96fc29d 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -98,7 +98,7 @@ function(make_torchcodec_libraries ) if(ENABLE_CUDA) - list(APPEND core_sources CudaDeviceInterface.cpp) + list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp) endif() set(core_library_dependencies @@ -111,6 +111,29 @@ function(make_torchcodec_libraries ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} ) + + # Try to find NVCUVID. Try the normal way first. This should work locally. + find_library(NVCUVID_LIBRARY NAMES nvcuvid) + # If not found, try with version suffix, or hardcoded path. Appears + # to be necessary on the CI. + if(NOT NVCUVID_LIBRARY) + find_library(NVCUVID_LIBRARY NAMES nvcuvid.1 PATHS /usr/lib64 /usr/lib) + endif() + if(NOT NVCUVID_LIBRARY) + set(NVCUVID_LIBRARY "/usr/lib64/libnvcuvid.so.1") + endif() + + if(NVCUVID_LIBRARY) + message(STATUS "Found NVCUVID: ${NVCUVID_LIBRARY}") + else() + message(FATAL_ERROR "Could not find NVCUVID library") + endif() + + # Add CUDA Driver library (needed for cuCtxGetCurrent, etc.) + find_library(CUDA_DRIVER_LIBRARY NAMES cuda REQUIRED) + message(STATUS "Found CUDA Driver library: ${CUDA_DRIVER_LIBRARY}") + + list(APPEND core_library_dependencies ${NVCUVID_LIBRARY} ${CUDA_DRIVER_LIBRARY}) endif() make_torchcodec_sublibrary( diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 6a69d4fc3..3fc25bb76 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -13,6 +13,15 @@ extern "C" { #include } +// TODONVDEC P1 Changes were made to this file to accomodate for the BETA CUDA +// interface (see other TODONVDEC below). That's because the BETA CUDA interface +// relies on this default CUDA interface to do the color conversion. That's +// hacky, ugly, and leads to complicated code. We should refactor all this so +// that an interface doesn't need to know anything about any other interface. +// Note - this is more than just about the BETA CUDA interface: this default +// interface already relies on the CPU interface to do software decoding when +// needed, and that's already leading to similar complications. + namespace facebook::torchcodec { namespace { @@ -216,10 +225,11 @@ std::unique_ptr CudaDeviceInterface::initializeFiltersContext( return nullptr; } - TORCH_CHECK( - avFrame->hw_frames_ctx != nullptr, - "The AVFrame does not have a hw_frames_ctx. " - "That's unexpected, please report this to the TorchCodec repo."); + if (avFrame->hw_frames_ctx == nullptr) { + // TODONVDEC P2 return early for for beta interface where avFrames don't + // have a hw_frames_ctx. We should get rid of this or improve the logic. + return nullptr; + } auto hwFramesCtx = reinterpret_cast(avFrame->hw_frames_ctx->data); @@ -347,22 +357,23 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // Above we checked that the AVFrame was on GPU, but that's not enough, we // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits), // because this is what the NPP color conversion routines expect. - TORCH_CHECK( - avFrame->hw_frames_ctx != nullptr, - "The AVFrame does not have a hw_frames_ctx. " - "That's unexpected, please report this to the TorchCodec repo."); - - auto hwFramesCtx = - reinterpret_cast(avFrame->hw_frames_ctx->data); - AVPixelFormat actualFormat = hwFramesCtx->sw_format; + // TODONVDEC P2 this can be hit from the beta interface, but there's no + // hw_frames_ctx in this case. We should try to understand how that affects + // this validation. + AVHWFramesContext* hwFramesCtx = nullptr; + if (avFrame->hw_frames_ctx != nullptr) { + hwFramesCtx = + reinterpret_cast(avFrame->hw_frames_ctx->data); + AVPixelFormat actualFormat = hwFramesCtx->sw_format; - TORCH_CHECK( - actualFormat == AV_PIX_FMT_NV12, - "The AVFrame is ", - (av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat) - : "unknown"), - ", but we expected AV_PIX_FMT_NV12. " - "That's unexpected, please report this to the TorchCodec repo."); + TORCH_CHECK( + actualFormat == AV_PIX_FMT_NV12, + "The AVFrame is ", + (av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat) + : "unknown"), + ", but we expected AV_PIX_FMT_NV12. " + "That's unexpected, please report this to the TorchCodec repo."); + } auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame); @@ -396,19 +407,23 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // arbitrary, but unfortunately we know it's hardcoded to be the default // stream by FFmpeg: // https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388 - TORCH_CHECK( - hwFramesCtx->device_ctx != nullptr, - "The AVFrame's hw_frames_ctx does not have a device_ctx. "); - auto cudaDeviceCtx = - static_cast(hwFramesCtx->device_ctx->hwctx); - at::cuda::CUDAEvent nvdecDoneEvent; - at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad. - c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex); - nvdecDoneEvent.record(nvdecStream); - - // Don't start NPP work before NVDEC is done decoding the frame! at::cuda::CUDAStream nppStream = at::cuda::getCurrentCUDAStream(deviceIndex); - nvdecDoneEvent.block(nppStream); + if (hwFramesCtx) { + // TODONVDEC P2 this block won't be hit from the beta interface because + // there is no hwFramesCtx, but we should still make sure there's no CUDA + // stream sync issue in the beta interface. + TORCH_CHECK( + hwFramesCtx->device_ctx != nullptr, + "The AVFrame's hw_frames_ctx does not have a device_ctx. "); + auto cudaDeviceCtx = + static_cast(hwFramesCtx->device_ctx->hwctx); + at::cuda::CUDAEvent nvdecDoneEvent; + at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad. + c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex); + nvdecDoneEvent.record(nvdecStream); + // Don't start NPP work before NVDEC is done decoding the frame! + nvdecDoneEvent.block(nppStream); + } // Create the NPP context if we haven't yet. nppCtx_->hStream = nppStream.stream(); diff --git a/src/torchcodec/_core/DeviceInterface.cpp b/src/torchcodec/_core/DeviceInterface.cpp index 70b00fb62..a1cc69d97 100644 --- a/src/torchcodec/_core/DeviceInterface.cpp +++ b/src/torchcodec/_core/DeviceInterface.cpp @@ -11,7 +11,8 @@ namespace facebook::torchcodec { namespace { -using DeviceInterfaceMap = std::map; +using DeviceInterfaceMap = + std::map; static std::mutex g_interface_mutex; DeviceInterfaceMap& getDeviceMap() { @@ -30,50 +31,79 @@ std::string getDeviceType(const std::string& device) { } // namespace bool registerDeviceInterface( - torch::DeviceType deviceType, + const DeviceInterfaceKey& key, CreateDeviceInterfaceFn createInterface) { std::scoped_lock lock(g_interface_mutex); DeviceInterfaceMap& deviceMap = getDeviceMap(); TORCH_CHECK( - deviceMap.find(deviceType) == deviceMap.end(), - "Device interface already registered for ", - deviceType); - deviceMap.insert({deviceType, createInterface}); + deviceMap.find(key) == deviceMap.end(), + "Device interface already registered for device type ", + key.deviceType, + " variant '", + key.variant, + "'"); + deviceMap.insert({key, createInterface}); return true; } -torch::Device createTorchDevice(const std::string device) { +bool registerDeviceInterface( + torch::DeviceType deviceType, + CreateDeviceInterfaceFn createInterface) { + return registerDeviceInterface( + DeviceInterfaceKey(deviceType), createInterface); +} + +void validateDeviceInterface( + const std::string device, + const std::string variant) { std::scoped_lock lock(g_interface_mutex); std::string deviceType = getDeviceType(device); + DeviceInterfaceMap& deviceMap = getDeviceMap(); + // Find device interface that matches device type and variant + torch::DeviceType deviceTypeEnum = torch::Device(deviceType).type(); + auto deviceInterface = std::find_if( deviceMap.begin(), deviceMap.end(), - [&](const std::pair& arg) { - return device.rfind( - torch::DeviceTypeName(arg.first, /*lcase*/ true), 0) == 0; + [&](const std::pair& arg) { + return arg.first.deviceType == deviceTypeEnum && + arg.first.variant == variant; }); - TORCH_CHECK( - deviceInterface != deviceMap.end(), "Unsupported device: ", device); - return torch::Device(device); + TORCH_CHECK( + deviceInterface != deviceMap.end(), + "Unsupported device: ", + device, + " (device type: ", + deviceType, + ", variant: ", + variant, + ")"); } std::unique_ptr createDeviceInterface( - const torch::Device& device) { - auto deviceType = device.type(); + const torch::Device& device, + const std::string_view variant) { + DeviceInterfaceKey key(device.type(), variant); std::scoped_lock lock(g_interface_mutex); DeviceInterfaceMap& deviceMap = getDeviceMap(); - TORCH_CHECK( - deviceMap.find(deviceType) != deviceMap.end(), - "Unsupported device: ", - device); + auto it = deviceMap.find(key); + if (it != deviceMap.end()) { + return std::unique_ptr(it->second(device)); + } - return std::unique_ptr(deviceMap[deviceType](device)); + TORCH_CHECK( + false, + "No device interface found for device type: ", + device.type(), + " variant: '", + variant, + "'"); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 9a7288eb0..b36255948 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -17,6 +17,24 @@ namespace facebook::torchcodec { +// Key for device interface registration with device type + variant support +struct DeviceInterfaceKey { + torch::DeviceType deviceType; + std::string_view variant = "default"; // e.g., "default", "beta", etc. + + bool operator<(const DeviceInterfaceKey& other) const { + if (deviceType != other.deviceType) { + return deviceType < other.deviceType; + } + return variant < other.variant; + } + + explicit DeviceInterfaceKey(torch::DeviceType type) : deviceType(type) {} + + DeviceInterfaceKey(torch::DeviceType type, const std::string_view& var) + : deviceType(type), variant(var) {} +}; + class DeviceInterface { public: DeviceInterface(const torch::Device& device) : device_(device) {} @@ -27,11 +45,17 @@ class DeviceInterface { return device_; }; - virtual std::optional findCodec(const AVCodecID& codecId) = 0; + virtual std::optional findCodec( + [[maybe_unused]] const AVCodecID& codecId) { + return std::nullopt; + }; // Initialize the hardware device that is specified in `device`. Some builds // support CUDA and others only support CPU. - virtual void initializeContext(AVCodecContext* codecContext) = 0; + virtual void initializeContext( + [[maybe_unused]] AVCodecContext* codecContext) {} + + virtual void initializeInterface([[maybe_unused]] AVStream* stream) {} virtual void convertAVFrameToFrameOutput( const VideoStreamOptions& videoStreamOptions, @@ -40,6 +64,53 @@ class DeviceInterface { FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt) = 0; + // ------------------------------------------ + // Extension points for custom decoding paths + // ------------------------------------------ + + // Override to return true if this device interface can decode packets + // directly + virtual bool canDecodePacketDirectly() const { + return false; + } + + // Moral equivalent of avcodec_send_packet() + // Returns AVSUCCESS on success, AVERROR(EAGAIN) if decoder queue full, or + // other AVERROR on failure + virtual int sendPacket([[maybe_unused]] ReferenceAVPacket& avPacket) { + TORCH_CHECK( + false, + "Send/receive packet decoding not implemented for this device interface"); + return AVERROR(ENOSYS); + } + + // Moral equivalent of avcodec_receive_frame() + // Returns AVSUCCESS on success, AVERROR(EAGAIN) if no frame ready, + // AVERROR_EOF if end of stream, or other AVERROR on failure + virtual int receiveFrame( + [[maybe_unused]] UniqueAVFrame& avFrame, + [[maybe_unused]] int64_t desiredPts) { + TORCH_CHECK( + false, + "Send/receive packet decoding not implemented for this device interface"); + return AVERROR(ENOSYS); + } + + // Flush remaining frames from decoder + virtual void flush() { + // Default implementation is no-op for standard decoders + // Custom decoders can override this method + } + + // Apply bitstream filter if needed, returns pointer to packet to use + // Default implementation returns the original packet (no filtering) + virtual ReferenceAVPacket* applyBSF( + ReferenceAVPacket& packet, + [[maybe_unused]] AutoAVPacket& filteredAutoPacket, + [[maybe_unused]] ReferenceAVPacket& filteredPacket) { + return &packet; // No filtering by default + } + protected: torch::Device device_; }; @@ -47,13 +118,23 @@ class DeviceInterface { using CreateDeviceInterfaceFn = std::function; +bool registerDeviceInterface( + const DeviceInterfaceKey& key, + const CreateDeviceInterfaceFn createInterface); + +// Backward-compatible registration function when variant is "default" +// TODONVDEC P2 We only need this if someone in the wild has already started +// registering their own interfaces. Ask Dmitry. bool registerDeviceInterface( torch::DeviceType deviceType, const CreateDeviceInterfaceFn createInterface); -torch::Device createTorchDevice(const std::string device); +void validateDeviceInterface( + const std::string device, + const std::string variant); std::unique_ptr createDeviceInterface( - const torch::Device& device); + const torch::Device& device, + const std::string_view variant = "default"); } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 9ce7a4deb..0d37e0b94 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -56,6 +56,14 @@ int64_t getDuration(const UniqueAVFrame& avFrame) { #endif } +void setDuration(const UniqueAVFrame& avFrame, int64_t duration) { +#if LIBAVUTIL_VERSION_MAJOR < 58 + avFrame->pkt_duration = duration; +#else + avFrame->duration = duration; +#endif +} + const int* getSupportedSampleRates(const AVCodec& avCodec) { const int* supportedSampleRates = nullptr; #if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) // FFmpeg >= 7.1 diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 179c7464b..2a8165370 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -12,6 +12,7 @@ extern "C" { #include +#include #include #include #include @@ -86,6 +87,8 @@ using UniqueSwrContext = std::unique_ptr>; using UniqueAVAudioFifo = std:: unique_ptr>; +using UniqueAVBSFContext = + std::unique_ptr>; using UniqueAVBufferRef = std::unique_ptr>; using UniqueAVBufferSrcParameters = std::unique_ptr< @@ -161,6 +164,7 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode); // struct member representing duration has changed across the versions we // support. int64_t getDuration(const UniqueAVFrame& frame); +void setDuration(const UniqueAVFrame& frame, int64_t duration); const int* getSupportedSampleRates(const AVCodec& avCodec); const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec); diff --git a/src/torchcodec/_core/NVDECCache.cpp b/src/torchcodec/_core/NVDECCache.cpp new file mode 100644 index 000000000..dbfb44bd2 --- /dev/null +++ b/src/torchcodec/_core/NVDECCache.cpp @@ -0,0 +1,70 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#include "src/torchcodec/_core/FFMPEGCommon.h" +#include "src/torchcodec/_core/NVDECCache.h" + +#include // For cudaGetDevice + +extern "C" { +#include +#include +} + +namespace facebook::torchcodec { + +NVDECCache& NVDECCache::GetCache(int deviceIndex) { + const int MAX_CUDA_GPUS = 128; + TORCH_CHECK( + deviceIndex >= -1 && deviceIndex < MAX_CUDA_GPUS, + "Invalid device index = ", + deviceIndex); + static NVDECCache cacheInstances[MAX_CUDA_GPUS]; + if (deviceIndex == -1) { + // TODO NVDEC P3: Unify with existing getNonNegativeDeviceIndex() + TORCH_CHECK( + cudaGetDevice(&deviceIndex) == cudaSuccess, + "Failed to get current CUDA device."); + } + return cacheInstances[deviceIndex]; +} + +UniqueCUvideodecoder NVDECCache::getDecoder(CUVIDEOFORMAT* videoFormat) { + CacheKey key(videoFormat); + std::lock_guard lock(cacheLock_); + + auto it = cache_.find(key); + if (it != cache_.end()) { + auto decoder = std::move(it->second); + cache_.erase(it); + return decoder; + } + + return nullptr; +} + +bool NVDECCache::returnDecoder( + CUVIDEOFORMAT* videoFormat, + UniqueCUvideodecoder decoder) { + if (!decoder) { + return false; + } + + CacheKey key(videoFormat); + std::lock_guard lock(cacheLock_); + + if (cache_.size() >= MAX_CACHE_SIZE) { + return false; + } + + cache_[key] = std::move(decoder); + return true; +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/NVDECCache.h b/src/torchcodec/_core/NVDECCache.h new file mode 100644 index 000000000..5618100f3 --- /dev/null +++ b/src/torchcodec/_core/NVDECCache.h @@ -0,0 +1,102 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +#include +#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h" +#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h" + +namespace facebook::torchcodec { + +// This file implements a cache for NVDEC decoders. +// TODONVDEC P3: Consider merging this with Cache.h. The main difference is that +// this NVDEC Cache involves a cache key (the decoder parameters). + +struct CUvideoDecoderDeleter { + void operator()(CUvideodecoder decoder) const { + if (decoder) { + cuvidDestroyDecoder(decoder); + } + } +}; + +using UniqueCUvideodecoder = std::unique_ptr; + +// A per-device cache for NVDEC decoders. There is one instance of this class +// per GPU device, and it is accessed through the static GetCache() method. +class NVDECCache { + public: + static NVDECCache& GetCache(int deviceIndex); + + // Get decoder from cache - returns nullptr if none available + UniqueCUvideodecoder getDecoder(CUVIDEOFORMAT* videoFormat); + + // Return decoder to cache - returns true if added to cache + bool returnDecoder(CUVIDEOFORMAT* videoFormat, UniqueCUvideodecoder decoder); + + private: + // Cache key struct: a decoder can be reused and taken from the cache only if + // all these parameters match. + struct CacheKey { + cudaVideoCodec codecType; + unsigned width; + unsigned height; + cudaVideoChromaFormat chromaFormat; + unsigned int bitDepthLumaMinus8; + unsigned char numDecodeSurfaces; + + CacheKey() = delete; + + explicit CacheKey(CUVIDEOFORMAT* videoFormat) + : codecType(videoFormat->codec), + width(videoFormat->coded_width), + height(videoFormat->coded_height), + chromaFormat(videoFormat->chroma_format), + bitDepthLumaMinus8(videoFormat->bit_depth_luma_minus8), + numDecodeSurfaces(videoFormat->min_num_decode_surfaces) {} + + CacheKey(const CacheKey&) = default; + CacheKey& operator=(const CacheKey&) = default; + + // TODONVDEC P2: we only implement operator< which is enough for std::map, + // but: + // - we should consider using std::unordered_map + // - we should consider a more sophisticated and potentially less strict + // cache key comparison logic + bool operator<(const CacheKey& other) const { + return std::tie( + codecType, + width, + height, + chromaFormat, + bitDepthLumaMinus8, + numDecodeSurfaces) < + std::tie( + other.codecType, + other.width, + other.height, + other.chromaFormat, + other.bitDepthLumaMinus8, + other.numDecodeSurfaces); + } + }; + + NVDECCache() = default; + ~NVDECCache() = default; + + std::map cache_; + std::mutex cacheLock_; + + // Max number of cached decoders, per device + static constexpr int MAX_CACHE_SIZE = 20; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index b52556e93..2ce79cfa6 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -384,6 +384,7 @@ void SingleStreamDecoder::addStream( int streamIndex, AVMediaType mediaType, const torch::Device& device, + const std::string_view deviceVariant, std::optional ffmpegThreadCount) { TORCH_CHECK( activeStreamIndex_ == NO_ACTIVE_STREAM, @@ -412,7 +413,7 @@ void SingleStreamDecoder::addStream( streamInfo.stream = formatContext_->streams[activeStreamIndex_]; streamInfo.avMediaType = mediaType; - deviceInterface_ = createDeviceInterface(device); + deviceInterface_ = createDeviceInterface(device, deviceVariant); // This should never happen, checking just to be safe. TORCH_CHECK( @@ -446,6 +447,7 @@ void SingleStreamDecoder::addStream( if (mediaType == AVMEDIA_TYPE_VIDEO) { if (deviceInterface_) { deviceInterface_->initializeContext(codecContext); + deviceInterface_->initializeInterface(streamInfo.stream); } } @@ -453,6 +455,7 @@ void SingleStreamDecoder::addStream( TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal)); codecContext->time_base = streamInfo.stream->time_base; + containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName = std::string(avcodec_get_name(codecContext->codec_id)); @@ -475,6 +478,7 @@ void SingleStreamDecoder::addVideoStream( streamIndex, AVMEDIA_TYPE_VIDEO, videoStreamOptions.device, + videoStreamOptions.deviceVariant, videoStreamOptions.ffmpegThreadCount); auto& streamMetadata = @@ -1105,6 +1109,10 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() { decodeStats_.numFlushes++; avcodec_flush_buffers(streamInfo.codecContext.get()); + + if (deviceInterface_) { + deviceInterface_->flush(); + } } // -------------------------------------------------------------------------- @@ -1123,15 +1131,26 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( } StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; - - // Need to get the next frame or error from PopFrame. UniqueAVFrame avFrame(av_frame_alloc()); AutoAVPacket autoAVPacket; int status = AVSUCCESS; bool reachedEOF = false; + + // TODONVDEC P2: Instead of defining useCustomInterface and rely on if/else + // blocks to dispatch to the interface or to FFmpeg, consider *always* + // dispatching to the interface. The default implementation of the interface's + // receiveFrame and sendPacket could just be calling avcodec_receive_frame and + // avcodec_send_packet. This would make the decoding loop even more generic. + bool useCustomInterface = + deviceInterface_ && deviceInterface_->canDecodePacketDirectly(); + while (true) { - status = - avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get()); + if (useCustomInterface) { + status = deviceInterface_->receiveFrame(avFrame, cursor_); + } else { + status = + avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get()); + } if (status != AVSUCCESS && status != AVERROR(EAGAIN)) { // Non-retriable error @@ -1154,7 +1173,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( if (reachedEOF) { // We don't have any more packets to receive. So keep on pulling frames - // from its internal buffers. + // from decoder's internal buffers. continue; } @@ -1166,11 +1185,18 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( decodeStats_.numPacketsRead++; if (status == AVERROR_EOF) { - // End of file reached. We must drain the codec by sending a nullptr - // packet. - status = avcodec_send_packet( - streamInfo.codecContext.get(), - /*avpkt=*/nullptr); + // End of file reached. We must drain the decoder + if (useCustomInterface) { + AutoAVPacket eofAutoPacket; + ReferenceAVPacket eofPacket(eofAutoPacket); + eofPacket->data = nullptr; + eofPacket->size = 0; + status = deviceInterface_->sendPacket(eofPacket); + } else { + status = avcodec_send_packet( + streamInfo.codecContext.get(), + /*avpkt=*/nullptr); + } TORCH_CHECK( status >= AVSUCCESS, "Could not flush decoder: ", @@ -1195,7 +1221,19 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( // We got a valid packet. Send it to the decoder, and we'll receive it in // the next iteration. - status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get()); + if (useCustomInterface) { + // TODONVDEC P0: + // - cleanup this raw pointer / reference monstruosity. + // - don't even expose applyBSF in the interface. This should just be part + // of sendPacket(). + AutoAVPacket filteredAutoPacket; + ReferenceAVPacket filteredPacket(filteredAutoPacket); + ReferenceAVPacket* packetToSend = deviceInterface_->applyBSF( + packet, filteredAutoPacket, filteredPacket); + status = deviceInterface_->sendPacket(*packetToSend); + } else { + status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get()); + } TORCH_CHECK( status >= AVSUCCESS, "Could not push packet to decoder: ", diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 56bb8bb58..779acd273 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -318,6 +318,7 @@ class SingleStreamDecoder { int streamIndex, AVMediaType mediaType, const torch::Device& device = torch::kCPU, + const std::string_view deviceVariant = "default", std::optional ffmpegThreadCount = std::nullopt); // Returns the "best" stream index for a given media type. The "best" is diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index 19cc5126c..65f2782a8 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -9,6 +9,7 @@ #include #include #include +#include namespace facebook::torchcodec { @@ -38,6 +39,8 @@ struct VideoStreamOptions { std::optional colorConversionLibrary; // By default we use CPU for decoding for both C++ and python users. torch::Device device = torch::kCPU; + // Device variant (e.g., "default", "beta", etc.) + std::string_view deviceVariant = "default"; // Encoding options std::optional bitRate; diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index a865bdaed..1e8945cbd 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -43,9 +43,9 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "_create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor"); m.def( - "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()"); + "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? device_variant=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()"); m.def( - "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()"); + "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? device_variant=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()"); m.def( "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()"); m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()"); @@ -258,6 +258,7 @@ void _add_video_stream( std::optional dimension_order = std::nullopt, std::optional stream_index = std::nullopt, std::optional device = std::nullopt, + std::optional device_variant = std::nullopt, std::optional> custom_frame_mappings = std::nullopt, std::optional color_conversion_library = std::nullopt) { @@ -287,9 +288,18 @@ void _add_video_stream( ". color_conversion_library must be either filtergraph or swscale."); } } - if (device.has_value()) { - videoStreamOptions.device = createTorchDevice(std::string(device.value())); + + if (!device.has_value()) { + device = "cpu"; + } + if (!device_variant.has_value()) { + device_variant = "default"; } + validateDeviceInterface(std::string(*device), std::string(*device_variant)); + + videoStreamOptions.device = torch::Device(std::string(*device)); + videoStreamOptions.deviceVariant = *device_variant; + std::optional converted_mappings = custom_frame_mappings.has_value() ? std::make_optional(makeFrameMappings(custom_frame_mappings.value())) @@ -308,6 +318,7 @@ void add_video_stream( std::optional dimension_order = std::nullopt, std::optional stream_index = std::nullopt, std::optional device = std::nullopt, + std::optional device_variant = std::nullopt, const std::optional>& custom_frame_mappings = std::nullopt) { _add_video_stream( @@ -318,6 +329,7 @@ void add_video_stream( dimension_order, stream_index, device, + device_variant, custom_frame_mappings); } diff --git a/src/torchcodec/_core/nvcuvid_include/cuviddec.h b/src/torchcodec/_core/nvcuvid_include/cuviddec.h new file mode 100644 index 000000000..4e70fe5a4 --- /dev/null +++ b/src/torchcodec/_core/nvcuvid_include/cuviddec.h @@ -0,0 +1,1374 @@ +/* + * This copyright notice applies to this header file only: + * + * Copyright (c) 2010-2024 NVIDIA Corporation + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the software, and to permit persons to whom the + * software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +/*****************************************************************************************************/ +//! \file cuviddec.h +//! NVDECODE API provides video decoding interface to NVIDIA GPU devices. +//! This file contains constants, structure definitions and function prototypes +//! used for decoding. +/*****************************************************************************************************/ + +#if !defined(__CUDA_VIDEO_H__) +#define __CUDA_VIDEO_H__ + +#ifndef __cuda_cuda_h__ +#include +#endif // __cuda_cuda_h__ + +#if defined(_WIN64) || defined(__LP64__) || defined(__x86_64) || \ + defined(AMD64) || defined(_M_AMD64) +#if (CUDA_VERSION >= 3020) && \ + (!defined(CUDA_FORCE_API_VERSION) || (CUDA_FORCE_API_VERSION >= 3020)) +#define __CUVID_DEVPTR64 +#endif +#endif + +#if defined(__cplusplus) +extern "C" { +#endif /* __cplusplus */ + +typedef void* CUvideodecoder; +typedef struct _CUcontextlock_st* CUvideoctxlock; + +/*********************************************************************************/ +//! \enum cudaVideoCodec +//! Video codec enums +//! These enums are used in CUVIDDECODECREATEINFO and CUVIDDECODECAPS structures +/*********************************************************************************/ +typedef enum cudaVideoCodec_enum { + cudaVideoCodec_MPEG1 = 0, /**< MPEG1 */ + cudaVideoCodec_MPEG2, /**< MPEG2 */ + cudaVideoCodec_MPEG4, /**< MPEG4 */ + cudaVideoCodec_VC1, /**< VC1 */ + cudaVideoCodec_H264, /**< H264 */ + cudaVideoCodec_JPEG, /**< JPEG */ + cudaVideoCodec_H264_SVC, /**< H264-SVC */ + cudaVideoCodec_H264_MVC, /**< H264-MVC */ + cudaVideoCodec_HEVC, /**< HEVC */ + cudaVideoCodec_VP8, /**< VP8 */ + cudaVideoCodec_VP9, /**< VP9 */ + cudaVideoCodec_AV1, /**< AV1 */ + cudaVideoCodec_NumCodecs, /**< Max codecs */ + // Uncompressed YUV + cudaVideoCodec_YUV420 = + (('I' << 24) | ('Y' << 16) | ('U' << 8) | ('V')), /**< Y,U,V (4:2:0) */ + cudaVideoCodec_YV12 = + (('Y' << 24) | ('V' << 16) | ('1' << 8) | ('2')), /**< Y,V,U (4:2:0) */ + cudaVideoCodec_NV12 = + (('N' << 24) | ('V' << 16) | ('1' << 8) | ('2')), /**< Y,UV (4:2:0) */ + cudaVideoCodec_YUYV = + (('Y' << 24) | ('U' << 16) | ('Y' << 8) | + ('V')), /**< YUYV/YUY2 (4:2:2) */ + cudaVideoCodec_UYVY = + (('U' << 24) | ('Y' << 16) | ('V' << 8) | ('Y')) /**< UYVY (4:2:2) */ +} cudaVideoCodec; + +/*********************************************************************************/ +//! \enum cudaVideoSurfaceFormat +//! Video surface format enums used for output format of decoded output +//! These enums are used in CUVIDDECODECREATEINFO structure +/*********************************************************************************/ +typedef enum cudaVideoSurfaceFormat_enum { + cudaVideoSurfaceFormat_NV12 = + 0, /**< Semi-Planar YUV [Y plane followed by interleaved UV plane] */ + cudaVideoSurfaceFormat_P016 = + 1, /**< 16 bit Semi-Planar YUV [Y plane followed by interleaved UV plane]. + Can be used for 10 bit(6LSB bits 0), 12 bit (4LSB bits 0) */ + cudaVideoSurfaceFormat_YUV444 = + 2, /**< Planar YUV [Y plane followed by U and V planes] */ + cudaVideoSurfaceFormat_YUV444_16Bit = + 3, /**< 16 bit Planar YUV [Y plane followed by U and V planes]. + Can be used for 10 bit(6LSB bits 0), 12 bit (4LSB bits 0) */ +} cudaVideoSurfaceFormat; + +/******************************************************************************************************************/ +//! \enum cudaVideoDeinterlaceMode +//! Deinterlacing mode enums +//! These enums are used in CUVIDDECODECREATEINFO structure +//! Use cudaVideoDeinterlaceMode_Weave for progressive content and for content +//! that doesn't need deinterlacing cudaVideoDeinterlaceMode_Adaptive needs more +//! video memory than other DImodes +/******************************************************************************************************************/ +typedef enum cudaVideoDeinterlaceMode_enum { + cudaVideoDeinterlaceMode_Weave = + 0, /**< Weave both fields (no deinterlacing) */ + cudaVideoDeinterlaceMode_Bob, /**< Drop one field */ + cudaVideoDeinterlaceMode_Adaptive /**< Adaptive deinterlacing */ +} cudaVideoDeinterlaceMode; + +/**************************************************************************************************************/ +//! \enum cudaVideoChromaFormat +//! Chroma format enums +//! These enums are used in CUVIDDECODECREATEINFO and CUVIDDECODECAPS structures +/**************************************************************************************************************/ +typedef enum cudaVideoChromaFormat_enum { + cudaVideoChromaFormat_Monochrome = 0, /**< MonoChrome */ + cudaVideoChromaFormat_420, /**< YUV 4:2:0 */ + cudaVideoChromaFormat_422, /**< YUV 4:2:2 */ + cudaVideoChromaFormat_444 /**< YUV 4:4:4 */ +} cudaVideoChromaFormat; + +/*************************************************************************************************************/ +//! \enum cudaVideoCreateFlags +//! Decoder flag enums to select preferred decode path +//! cudaVideoCreate_Default and cudaVideoCreate_PreferCUVID are most optimized, +//! use these whenever possible +/*************************************************************************************************************/ +typedef enum cudaVideoCreateFlags_enum { + cudaVideoCreate_Default = + 0x00, /**< Default operation mode: use dedicated video engines */ + cudaVideoCreate_PreferCUDA = 0x01, /**< Use CUDA-based decoder (requires valid + vidLock object for multi-threading) */ + cudaVideoCreate_PreferDXVA = 0x02, /**< Go through DXVA internally if possible + (requires D3D9 interop) */ + cudaVideoCreate_PreferCUVID = + 0x04 /**< Use dedicated video engines directly */ +} cudaVideoCreateFlags; + +/*************************************************************************/ +//! \enum cuvidDecodeStatus +//! Decode status enums +//! These enums are used in CUVIDGETDECODESTATUS structure +/*************************************************************************/ +typedef enum cuvidDecodeStatus_enum { + cuvidDecodeStatus_Invalid = 0, // Decode status is not valid + cuvidDecodeStatus_InProgress = 1, // Decode is in progress + cuvidDecodeStatus_Success = 2, // Decode is completed without any errors + // 3 to 7 enums are reserved for future use + cuvidDecodeStatus_Error = + 8, // Decode is completed with an error (error is not concealed) + cuvidDecodeStatus_Error_Concealed = + 9, // Decode is completed with an error and error is concealed +} cuvidDecodeStatus; + +/**************************************************************************************************************/ +//! \struct CUVIDDECODECAPS; +//! This structure is used in cuvidGetDecoderCaps API +/**************************************************************************************************************/ +typedef struct _CUVIDDECODECAPS { + cudaVideoCodec eCodecType; /**< IN: cudaVideoCodec_XXX */ + cudaVideoChromaFormat eChromaFormat; /**< IN: cudaVideoChromaFormat_XXX */ + unsigned int nBitDepthMinus8; /**< IN: The Value "BitDepth minus 8" */ + unsigned int reserved1[3]; /**< Reserved for future use - set to zero */ + + unsigned char + bIsSupported; /**< OUT: 1 if codec supported, 0 if not supported */ + unsigned char + nNumNVDECs; /**< OUT: Number of NVDECs that can support IN params */ + unsigned short nOutputFormatMask; /**< OUT: each bit represents corresponding + cudaVideoSurfaceFormat enum */ + unsigned int nMaxWidth; /**< OUT: Max supported coded width in pixels */ + unsigned int nMaxHeight; /**< OUT: Max supported coded height in pixels */ + unsigned int nMaxMBCount; /**< OUT: Max supported macroblock count + CodedWidth*CodedHeight/256 must be <= + nMaxMBCount */ + unsigned short nMinWidth; /**< OUT: Min supported coded width in pixels */ + unsigned short nMinHeight; /**< OUT: Min supported coded height in pixels */ + unsigned char + bIsHistogramSupported; /**< OUT: 1 if Y component histogram output is + supported, 0 if not Note: histogram is computed + on original picture data before any + post-processing like scaling, cropping, etc. is + applied */ + unsigned char nCounterBitDepth; /**< OUT: histogram counter bit depth */ + unsigned short nMaxHistogramBins; /**< OUT: Max number of histogram bins */ + unsigned int reserved3[10]; /**< Reserved for future use - set to zero */ +} CUVIDDECODECAPS; + +/**************************************************************************************************************/ +//! \struct CUVIDDECODECREATEINFO +//! This structure is used in cuvidCreateDecoder API +/**************************************************************************************************************/ +typedef struct _CUVIDDECODECREATEINFO { + unsigned long ulWidth; /**< IN: Coded sequence width in pixels */ + unsigned long ulHeight; /**< IN: Coded sequence height in pixels */ + unsigned long ulNumDecodeSurfaces; /**< IN: Maximum number of internal decode + surfaces */ + cudaVideoCodec CodecType; /**< IN: cudaVideoCodec_XXX */ + cudaVideoChromaFormat ChromaFormat; /**< IN: cudaVideoChromaFormat_XXX */ + unsigned long ulCreationFlags; /**< IN: Decoder creation flags + (cudaVideoCreateFlags_XXX) */ + unsigned long bitDepthMinus8; /**< IN: The value "BitDepth minus 8" */ + unsigned long + ulIntraDecodeOnly; /**< IN: Set 1 only if video has all intra frames + (default value is 0). This will optimize video + memory for Intra frames only decoding. The support + is limited to specific codecs - H264, HEVC, VP9, the + flag will be ignored for codecs which are not + supported. However decoding might fail if the flag + is enabled in case of supported codecs for regular + bit streams having P and/or B frames. */ + unsigned long ulMaxWidth; /**< IN: Coded sequence max width in pixels used + with reconfigure Decoder */ + unsigned long ulMaxHeight; /**< IN: Coded sequence max height in pixels used + with reconfigure Decoder */ + unsigned long Reserved1; /**< Reserved for future use - set to zero */ + + /** + * IN: area of the frame that should be displayed + */ + struct { + short left; + short top; + short right; + short bottom; + } display_area; + + cudaVideoSurfaceFormat OutputFormat; /**< IN: cudaVideoSurfaceFormat_XXX */ + cudaVideoDeinterlaceMode + DeinterlaceMode; /**< IN: cudaVideoDeinterlaceMode_XXX */ + unsigned long ulTargetWidth; /**< IN: Post-processed output width (Should be + aligned to 2) */ + unsigned long ulTargetHeight; /**< IN: Post-processed output height (Should be + aligned to 2) */ + unsigned long ulNumOutputSurfaces; /**< IN: Maximum number of output surfaces + simultaneously mapped */ + CUvideoctxlock vidLock; /**< IN: If non-NULL, context lock used for + synchronizing ownership of the cuda context. Needed + for cudaVideoCreate_PreferCUDA decode */ + + /** + * IN: target rectangle in the output frame (for aspect ratio conversion) + * if a null rectangle is specified, {0,0,ulTargetWidth,ulTargetHeight} will + * be used + */ + struct { + short left; + short top; + short right; + short bottom; + } target_rect; + + unsigned long + enableHistogram; /**< IN: enable histogram output, if supported */ + unsigned long Reserved2[4]; /**< Reserved for future use - set to zero */ +} CUVIDDECODECREATEINFO; + +/*********************************************************/ +//! \struct CUVIDH264DPBENTRY +//! H.264 DPB entry +//! This structure is used in CUVIDH264PICPARAMS structure +/*********************************************************/ +typedef struct _CUVIDH264DPBENTRY { + int PicIdx; /**< picture index of reference frame */ + int FrameIdx; /**< frame_num(short-term) or LongTermFrameIdx(long-term) */ + int is_long_term; /**< 0=short term reference, 1=long term reference */ + int not_existing; /**< non-existing reference frame (corresponding PicIdx + should be set to -1) */ + int used_for_reference; /**< 0=unused, 1=top_field, 2=bottom_field, + 3=both_fields */ + int FieldOrderCnt[2]; /**< field order count of top and bottom fields */ +} CUVIDH264DPBENTRY; + +/************************************************************/ +//! \struct CUVIDH264MVCEXT +//! H.264 MVC picture parameters ext +//! This structure is used in CUVIDH264PICPARAMS structure +/************************************************************/ +typedef struct _CUVIDH264MVCEXT { + int num_views_minus1; /**< Max number of coded views minus 1 in video : Range + - 0 to 1023 */ + int view_id; /**< view identifier */ + unsigned char + inter_view_flag; /**< 1 if used for inter-view prediction, 0 if not */ + unsigned char num_inter_view_refs_l0; /**< number of inter-view ref pics in + RefPicList0 */ + unsigned char num_inter_view_refs_l1; /**< number of inter-view ref pics in + RefPicList1 */ + unsigned char MVCReserved8Bits; /**< Reserved bits */ + int InterViewRefsL0[16]; /**< view id of the i-th view component for + inter-view prediction in RefPicList0 */ + int InterViewRefsL1[16]; /**< view id of the i-th view component for + inter-view prediction in RefPicList1 */ +} CUVIDH264MVCEXT; + +/*********************************************************/ +//! \struct CUVIDH264SVCEXT +//! H.264 SVC picture parameters ext +//! This structure is used in CUVIDH264PICPARAMS structure +/*********************************************************/ +typedef struct _CUVIDH264SVCEXT { + unsigned char profile_idc; + unsigned char level_idc; + unsigned char DQId; + unsigned char DQIdMax; + unsigned char disable_inter_layer_deblocking_filter_idc; + unsigned char ref_layer_chroma_phase_y_plus1; + signed char inter_layer_slice_alpha_c0_offset_div2; + signed char inter_layer_slice_beta_offset_div2; + + unsigned short DPBEntryValidFlag; + unsigned char inter_layer_deblocking_filter_control_present_flag; + unsigned char extended_spatial_scalability_idc; + unsigned char adaptive_tcoeff_level_prediction_flag; + unsigned char slice_header_restriction_flag; + unsigned char chroma_phase_x_plus1_flag; + unsigned char chroma_phase_y_plus1; + + unsigned char tcoeff_level_prediction_flag; + unsigned char constrained_intra_resampling_flag; + unsigned char ref_layer_chroma_phase_x_plus1_flag; + unsigned char store_ref_base_pic_flag; + unsigned char Reserved8BitsA; + unsigned char Reserved8BitsB; + + short scaled_ref_layer_left_offset; + short scaled_ref_layer_top_offset; + short scaled_ref_layer_right_offset; + short scaled_ref_layer_bottom_offset; + unsigned short Reserved16Bits; + struct _CUVIDPICPARAMS* + pNextLayer; /**< Points to the picparams for the next layer to be decoded. + Linked list ends at the target layer. */ + int bRefBaseLayer; /**< whether to store ref base pic */ +} CUVIDH264SVCEXT; + +/******************************************************/ +//! \struct CUVIDH264PICPARAMS +//! H.264 picture parameters +//! This structure is used in CUVIDPICPARAMS structure +/******************************************************/ +typedef struct _CUVIDH264PICPARAMS { + // SPS + int log2_max_frame_num_minus4; + int pic_order_cnt_type; + int log2_max_pic_order_cnt_lsb_minus4; + int delta_pic_order_always_zero_flag; + int frame_mbs_only_flag; + int direct_8x8_inference_flag; + int num_ref_frames; // NOTE: shall meet level 4.1 restrictions + unsigned char residual_colour_transform_flag; + unsigned char bit_depth_luma_minus8; // Must be 0 (only 8-bit supported) + unsigned char bit_depth_chroma_minus8; // Must be 0 (only 8-bit supported) + unsigned char qpprime_y_zero_transform_bypass_flag; + // PPS + int entropy_coding_mode_flag; + int pic_order_present_flag; + int num_ref_idx_l0_active_minus1; + int num_ref_idx_l1_active_minus1; + int weighted_pred_flag; + int weighted_bipred_idc; + int pic_init_qp_minus26; + int deblocking_filter_control_present_flag; + int redundant_pic_cnt_present_flag; + int transform_8x8_mode_flag; + int MbaffFrameFlag; + int constrained_intra_pred_flag; + int chroma_qp_index_offset; + int second_chroma_qp_index_offset; + int ref_pic_flag; + int frame_num; + int CurrFieldOrderCnt[2]; + // DPB + CUVIDH264DPBENTRY dpb[16]; // List of reference frames within the DPB + // Quantization Matrices (raster-order) + unsigned char WeightScale4x4[6][16]; + unsigned char WeightScale8x8[2][64]; + // FMO/ASO + unsigned char fmo_aso_enable; + unsigned char num_slice_groups_minus1; + unsigned char slice_group_map_type; + signed char pic_init_qs_minus26; + unsigned int slice_group_change_rate_minus1; + + union { + unsigned long long slice_group_map_addr; + const unsigned char* pMb2SliceGroupMap; + } fmo; + + unsigned int Reserved[12]; + + // SVC/MVC + union { + CUVIDH264MVCEXT mvcext; + CUVIDH264SVCEXT svcext; + }; +} CUVIDH264PICPARAMS; + +/********************************************************/ +//! \struct CUVIDMPEG2PICPARAMS +//! MPEG-2 picture parameters +//! This structure is used in CUVIDPICPARAMS structure +/********************************************************/ +typedef struct _CUVIDMPEG2PICPARAMS { + int ForwardRefIdx; // Picture index of forward reference (P/B-frames) + int BackwardRefIdx; // Picture index of backward reference (B-frames) + int picture_coding_type; + int full_pel_forward_vector; + int full_pel_backward_vector; + int f_code[2][2]; + int intra_dc_precision; + int frame_pred_frame_dct; + int concealment_motion_vectors; + int q_scale_type; + int intra_vlc_format; + int alternate_scan; + int top_field_first; + // Quantization matrices (raster order) + unsigned char QuantMatrixIntra[64]; + unsigned char QuantMatrixInter[64]; +} CUVIDMPEG2PICPARAMS; + +// MPEG-4 has VOP types instead of Picture types +#define I_VOP 0 +#define P_VOP 1 +#define B_VOP 2 +#define S_VOP 3 + +/*******************************************************/ +//! \struct CUVIDMPEG4PICPARAMS +//! MPEG-4 picture parameters +//! This structure is used in CUVIDPICPARAMS structure +/*******************************************************/ +typedef struct _CUVIDMPEG4PICPARAMS { + int ForwardRefIdx; // Picture index of forward reference (P/B-frames) + int BackwardRefIdx; // Picture index of backward reference (B-frames) + // VOL + int video_object_layer_width; + int video_object_layer_height; + int vop_time_increment_bitcount; + int top_field_first; + int resync_marker_disable; + int quant_type; + int quarter_sample; + int short_video_header; + int divx_flags; + // VOP + int vop_coding_type; + int vop_coded; + int vop_rounding_type; + int alternate_vertical_scan_flag; + int interlaced; + int vop_fcode_forward; + int vop_fcode_backward; + int trd[2]; + int trb[2]; + // Quantization matrices (raster order) + unsigned char QuantMatrixIntra[64]; + unsigned char QuantMatrixInter[64]; + int gmc_enabled; +} CUVIDMPEG4PICPARAMS; + +/********************************************************/ +//! \struct CUVIDVC1PICPARAMS +//! VC1 picture parameters +//! This structure is used in CUVIDPICPARAMS structure +/********************************************************/ +typedef struct _CUVIDVC1PICPARAMS { + int ForwardRefIdx; /**< Picture index of forward reference (P/B-frames) */ + int BackwardRefIdx; /**< Picture index of backward reference (B-frames) */ + int FrameWidth; /**< Actual frame width */ + int FrameHeight; /**< Actual frame height */ + // PICTURE + int intra_pic_flag; /**< Set to 1 for I,BI frames */ + int ref_pic_flag; /**< Set to 1 for I,P frames */ + int progressive_fcm; /**< Progressive frame */ + // SEQUENCE + int profile; + int postprocflag; + int pulldown; + int interlace; + int tfcntrflag; + int finterpflag; + int psf; + int multires; + int syncmarker; + int rangered; + int maxbframes; + // ENTRYPOINT + int panscan_flag; + int refdist_flag; + int extended_mv; + int dquant; + int vstransform; + int loopfilter; + int fastuvmc; + int overlap; + int quantizer; + int extended_dmv; + int range_mapy_flag; + int range_mapy; + int range_mapuv_flag; + int range_mapuv; + int rangeredfrm; // range reduction state +} CUVIDVC1PICPARAMS; + +/***********************************************************/ +//! \struct CUVIDJPEGPICPARAMS +//! JPEG picture parameters +//! This structure is used in CUVIDPICPARAMS structure +/***********************************************************/ +typedef struct _CUVIDJPEGPICPARAMS { + int Reserved; +} CUVIDJPEGPICPARAMS; + +/*******************************************************/ +//! \struct CUVIDHEVCPICPARAMS +//! HEVC picture parameters +//! This structure is used in CUVIDPICPARAMS structure +/*******************************************************/ +typedef struct _CUVIDHEVCPICPARAMS { + // sps + int pic_width_in_luma_samples; + int pic_height_in_luma_samples; + unsigned char log2_min_luma_coding_block_size_minus3; + unsigned char log2_diff_max_min_luma_coding_block_size; + unsigned char log2_min_transform_block_size_minus2; + unsigned char log2_diff_max_min_transform_block_size; + unsigned char pcm_enabled_flag; + unsigned char log2_min_pcm_luma_coding_block_size_minus3; + unsigned char log2_diff_max_min_pcm_luma_coding_block_size; + unsigned char pcm_sample_bit_depth_luma_minus1; + + unsigned char pcm_sample_bit_depth_chroma_minus1; + unsigned char pcm_loop_filter_disabled_flag; + unsigned char strong_intra_smoothing_enabled_flag; + unsigned char max_transform_hierarchy_depth_intra; + unsigned char max_transform_hierarchy_depth_inter; + unsigned char amp_enabled_flag; + unsigned char separate_colour_plane_flag; + unsigned char log2_max_pic_order_cnt_lsb_minus4; + + unsigned char num_short_term_ref_pic_sets; + unsigned char long_term_ref_pics_present_flag; + unsigned char num_long_term_ref_pics_sps; + unsigned char sps_temporal_mvp_enabled_flag; + unsigned char sample_adaptive_offset_enabled_flag; + unsigned char scaling_list_enable_flag; + unsigned char IrapPicFlag; + unsigned char IdrPicFlag; + + unsigned char bit_depth_luma_minus8; + unsigned char bit_depth_chroma_minus8; + // sps/pps extension fields + unsigned char log2_max_transform_skip_block_size_minus2; + unsigned char log2_sao_offset_scale_luma; + unsigned char log2_sao_offset_scale_chroma; + unsigned char high_precision_offsets_enabled_flag; + unsigned char reserved1[10]; + + // pps + unsigned char dependent_slice_segments_enabled_flag; + unsigned char slice_segment_header_extension_present_flag; + unsigned char sign_data_hiding_enabled_flag; + unsigned char cu_qp_delta_enabled_flag; + unsigned char diff_cu_qp_delta_depth; + signed char init_qp_minus26; + signed char pps_cb_qp_offset; + signed char pps_cr_qp_offset; + + unsigned char constrained_intra_pred_flag; + unsigned char weighted_pred_flag; + unsigned char weighted_bipred_flag; + unsigned char transform_skip_enabled_flag; + unsigned char transquant_bypass_enabled_flag; + unsigned char entropy_coding_sync_enabled_flag; + unsigned char log2_parallel_merge_level_minus2; + unsigned char num_extra_slice_header_bits; + + unsigned char loop_filter_across_tiles_enabled_flag; + unsigned char loop_filter_across_slices_enabled_flag; + unsigned char output_flag_present_flag; + unsigned char num_ref_idx_l0_default_active_minus1; + unsigned char num_ref_idx_l1_default_active_minus1; + unsigned char lists_modification_present_flag; + unsigned char cabac_init_present_flag; + unsigned char pps_slice_chroma_qp_offsets_present_flag; + + unsigned char deblocking_filter_override_enabled_flag; + unsigned char pps_deblocking_filter_disabled_flag; + signed char pps_beta_offset_div2; + signed char pps_tc_offset_div2; + unsigned char tiles_enabled_flag; + unsigned char uniform_spacing_flag; + unsigned char num_tile_columns_minus1; + unsigned char num_tile_rows_minus1; + + unsigned short column_width_minus1[21]; + unsigned short row_height_minus1[21]; + + // sps and pps extension HEVC-main 444 + unsigned char sps_range_extension_flag; + unsigned char transform_skip_rotation_enabled_flag; + unsigned char transform_skip_context_enabled_flag; + unsigned char implicit_rdpcm_enabled_flag; + + unsigned char explicit_rdpcm_enabled_flag; + unsigned char extended_precision_processing_flag; + unsigned char intra_smoothing_disabled_flag; + unsigned char persistent_rice_adaptation_enabled_flag; + + unsigned char cabac_bypass_alignment_enabled_flag; + unsigned char pps_range_extension_flag; + unsigned char cross_component_prediction_enabled_flag; + unsigned char chroma_qp_offset_list_enabled_flag; + + unsigned char diff_cu_chroma_qp_offset_depth; + unsigned char chroma_qp_offset_list_len_minus1; + signed char cb_qp_offset_list[6]; + + signed char cr_qp_offset_list[6]; + unsigned char reserved2[2]; + + unsigned int reserved3[8]; + + // RefPicSets + int NumBitsForShortTermRPSInSlice; + int NumDeltaPocsOfRefRpsIdx; + int NumPocTotalCurr; + int NumPocStCurrBefore; + int NumPocStCurrAfter; + int NumPocLtCurr; + int CurrPicOrderCntVal; + int RefPicIdx[16]; // [refpic] Indices of valid reference pictures (-1 if + // unused for reference) + int PicOrderCntVal[16]; // [refpic] + unsigned char IsLongTerm[16]; // [refpic] 0=not a long-term reference, + // 1=long-term reference + unsigned char + RefPicSetStCurrBefore[8]; // [0..NumPocStCurrBefore-1] -> refpic (0..15) + unsigned char + RefPicSetStCurrAfter[8]; // [0..NumPocStCurrAfter-1] -> refpic (0..15) + unsigned char RefPicSetLtCurr[8]; // [0..NumPocLtCurr-1] -> refpic (0..15) + unsigned char RefPicSetInterLayer0[8]; + unsigned char RefPicSetInterLayer1[8]; + unsigned int reserved4[12]; + + // scaling lists (diag order) + unsigned char ScalingList4x4[6][16]; // [matrixId][i] + unsigned char ScalingList8x8[6][64]; // [matrixId][i] + unsigned char ScalingList16x16[6][64]; // [matrixId][i] + unsigned char ScalingList32x32[2][64]; // [matrixId][i] + unsigned char ScalingListDCCoeff16x16[6]; // [matrixId] + unsigned char ScalingListDCCoeff32x32[2]; // [matrixId] +} CUVIDHEVCPICPARAMS; + +/***********************************************************/ +//! \struct CUVIDVP8PICPARAMS +//! VP8 picture parameters +//! This structure is used in CUVIDPICPARAMS structure +/***********************************************************/ +typedef struct _CUVIDVP8PICPARAMS { + int width; + int height; + unsigned int first_partition_size; + // Frame Indexes + unsigned char LastRefIdx; + unsigned char GoldenRefIdx; + unsigned char AltRefIdx; + + union { + struct { + unsigned char frame_type : 1; /**< 0 = KEYFRAME, 1 = INTERFRAME */ + unsigned char version : 3; + unsigned char show_frame : 1; + unsigned char + update_mb_segmentation_data : 1; /**< Must be 0 if segmentation is not + enabled */ + unsigned char Reserved2Bits : 2; + } vp8_frame_tag; + + unsigned char wFrameTagFlags; + }; + + unsigned char Reserved1[4]; + unsigned int Reserved2[3]; +} CUVIDVP8PICPARAMS; + +/***********************************************************/ +//! \struct CUVIDVP9PICPARAMS +//! VP9 picture parameters +//! This structure is used in CUVIDPICPARAMS structure +/***********************************************************/ +typedef struct _CUVIDVP9PICPARAMS { + unsigned int width; + unsigned int height; + + // Frame Indices + unsigned char LastRefIdx; + unsigned char GoldenRefIdx; + unsigned char AltRefIdx; + unsigned char colorSpace; + + unsigned short profile : 3; + unsigned short frameContextIdx : 2; + unsigned short frameType : 1; + unsigned short showFrame : 1; + unsigned short errorResilient : 1; + unsigned short frameParallelDecoding : 1; + unsigned short subSamplingX : 1; + unsigned short subSamplingY : 1; + unsigned short intraOnly : 1; + unsigned short allow_high_precision_mv : 1; + unsigned short refreshEntropyProbs : 1; + unsigned short reserved2Bits : 2; + + unsigned short reserved16Bits; + + unsigned char refFrameSignBias[4]; + + unsigned char bitDepthMinus8Luma; + unsigned char bitDepthMinus8Chroma; + unsigned char loopFilterLevel; + unsigned char loopFilterSharpness; + + unsigned char modeRefLfEnabled; + unsigned char log2_tile_columns; + unsigned char log2_tile_rows; + + unsigned char segmentEnabled : 1; + unsigned char segmentMapUpdate : 1; + unsigned char segmentMapTemporalUpdate : 1; + unsigned char segmentFeatureMode : 1; + unsigned char reserved4Bits : 4; + + unsigned char segmentFeatureEnable[8][4]; + short segmentFeatureData[8][4]; + unsigned char mb_segment_tree_probs[7]; + unsigned char segment_pred_probs[3]; + unsigned char reservedSegment16Bits[2]; + + int qpYAc; + int qpYDc; + int qpChDc; + int qpChAc; + + unsigned int activeRefIdx[3]; + unsigned int resetFrameContext; + unsigned int mcomp_filter_type; + unsigned int mbRefLfDelta[4]; + unsigned int mbModeLfDelta[2]; + unsigned int frameTagSize; + unsigned int offsetToDctParts; + unsigned int reserved128Bits[4]; + +} CUVIDVP9PICPARAMS; + +/***********************************************************/ +//! \struct CUVIDAV1PICPARAMS +//! AV1 picture parameters +//! This structure is used in CUVIDPICPARAMS structure +/***********************************************************/ +typedef struct _CUVIDAV1PICPARAMS { + unsigned int + width; // coded width, if superres enabled then it is upscaled width + unsigned int height; // coded height + unsigned int frame_offset; // defined as order_hint in AV1 specification + int decodePicIdx; // decoded output pic index, if film grain enabled, it will + // keep decoded (without film grain) output It can be used + // as reference frame for future frames + + // sequence header + unsigned int profile : 3; // 0 = profile0, 1 = profile1, 2 = profile2 + unsigned int + use_128x128_superblock : 1; // superblock size 0:64x64, 1: 128x128 + unsigned int + subsampling_x : 1; // (subsampling_x, _y) 1,1 = 420, 1,0 = 422, 0,0 = 444 + unsigned int subsampling_y : 1; + unsigned int mono_chrome : 1; // for monochrome content, mono_chrome = 1 and + // (subsampling_x, _y) should be 1,1 + unsigned int bit_depth_minus8 : 4; // bit depth minus 8 + unsigned int enable_filter_intra : 1; // tool enable in seq level, 0 : disable + // 1: frame header control + unsigned int enable_intra_edge_filter : 1; // intra edge filtering process, 0 + // : disable 1: enabled + unsigned int + enable_interintra_compound : 1; // interintra, 0 : not present 1: present + unsigned int + enable_masked_compound : 1; // 1: mode info for inter blocks may contain + // the syntax element compound_type. 0: syntax + // element compound_type will not be present + unsigned int enable_dual_filter : 1; // vertical and horiz filter selection, + // 1: enable and 0: disable + unsigned int enable_order_hint : 1; // order hint, and related tools, 1: + // enable and 0: disable + unsigned int order_hint_bits_minus1 : 3; // is used to compute OrderHintBits + unsigned int + enable_jnt_comp : 1; // joint compound modes, 1: enable and 0: disable + unsigned int enable_superres : 1; // superres in seq level, 0 : disable 1: + // frame level control + unsigned int enable_cdef : 1; // cdef filtering in seq level, 0 : disable 1: + // frame level control + unsigned int + enable_restoration : 1; // loop restoration filtering in seq level, 0 : + // disable 1: frame level control + unsigned int enable_fgs : 1; // defined as film_grain_params_present in AV1 + // specification + unsigned int reserved0_7bits : 7; // reserved bits; must be set to 0 + + // frame header + unsigned int + frame_type : 2; // 0:Key frame, 1:Inter frame, 2:intra only, 3:s-frame + unsigned int show_frame : 1; // show_frame = 1 implies that frame should be + // immediately output once decoded + unsigned int disable_cdf_update : 1; // CDF update during symbol decoding, 1: + // disabled, 0: enabled + unsigned int allow_screen_content_tools : 1; // 1: intra blocks may use + // palette encoding, 0: palette + // encoding is never used + unsigned int force_integer_mv : 1; // 1: motion vectors will always be + // integers, 0: can contain fractional bits + unsigned int coded_denom : 3; // coded_denom of the superres scale as + // specified in AV1 specification + unsigned int allow_intrabc : 1; // 1: intra block copy may be used, 0: intra + // block copy is not allowed + unsigned int allow_high_precision_mv : 1; // 1/8 precision mv enable + unsigned int + interp_filter : 3; // interpolation filter. Refer to section 6.8.9 of the + // AV1 specification Version 1.0.0 with Errata 1 + unsigned int + switchable_motion_mode : 1; // defined as is_motion_mode_switchable in AV1 + // specification + unsigned int use_ref_frame_mvs : 1; // 1: current frame can use the previous + // frame mv information, 0: will not use. + unsigned int disable_frame_end_update_cdf : 1; // 1: indicates that the end of + // frame CDF update is disabled + unsigned int delta_q_present : 1; // quantizer index delta values are present + // in the block level + unsigned int delta_q_res : 2; // left shift which should be applied to decoded + // quantizer index delta values + unsigned int using_qmatrix : 1; // 1: quantizer matrix will be used to compute + // quantizers + unsigned int coded_lossless : 1; // 1: all segments use lossless coding + unsigned int use_superres : 1; // 1: superres enabled for frame + unsigned int tx_mode : 2; // 0: ONLY4x4,1:LARGEST,2:SELECT + unsigned int reference_mode : 1; // 0: SINGLE, 1: SELECT + unsigned int + allow_warped_motion : 1; // 1: allow_warped_motion may be present, 0: + // allow_warped_motion will not be present + unsigned int + reduced_tx_set : 1; // 1: frame is restricted to subset of the full set of + // transform types, 0: no such restriction + unsigned int skip_mode : 1; // 1: most of the mode info is skipped, 0: mode + // info is not skipped + unsigned int reserved1_3bits : 3; // reserved bits; must be set to 0 + + // tiling info + unsigned int + num_tile_cols : 8; // number of tiles across the frame., max is 64 + unsigned int num_tile_rows : 8; // number of tiles down the frame., max is 64 + unsigned int context_update_tile_id : 16; // specifies which tile to use for + // the CDF update + unsigned short tile_widths[64]; // Width of each column in superblocks + unsigned short tile_heights[64]; // height of each row in superblocks + + // CDEF - refer to section 6.10.14 of the AV1 specification Version 1.0.0 with + // Errata 1 + unsigned char cdef_damping_minus_3 : 2; // controls the amount of damping in + // the deringing filter + unsigned char cdef_bits : 2; // the number of bits needed to specify which + // CDEF filter to apply + unsigned char reserved2_4bits : 4; // reserved bits; must be set to 0 + unsigned char + cdef_y_strength[8]; // 0-3 bits: y_pri_strength, 4-7 bits y_sec_strength + unsigned char cdef_uv_strength[8]; // 0-3 bits: uv_pri_strength, 4-7 bits + // uv_sec_strength + + // SkipModeFrames + unsigned char SkipModeFrame0 : 4; // specifies the frames to use for compound + // prediction when skip_mode is equal to 1. + unsigned char SkipModeFrame1 : 4; + + // qp information - refer to section 6.8.11 of the AV1 specification + // Version 1.0.0 with Errata 1 + unsigned char base_qindex; // indicates the base frame qindex. Defined as + // base_q_idx in AV1 specification + char qp_y_dc_delta_q; // indicates the Y DC quantizer relative to base_q_idx. + // Defined as DeltaQYDc in AV1 specification + char qp_u_dc_delta_q; // indicates the U DC quantizer relative to base_q_idx. + // Defined as DeltaQUDc in AV1 specification + char qp_v_dc_delta_q; // indicates the V DC quantizer relative to base_q_idx. + // Defined as DeltaQVDc in AV1 specification + char qp_u_ac_delta_q; // indicates the U AC quantizer relative to base_q_idx. + // Defined as DeltaQUAc in AV1 specification + char qp_v_ac_delta_q; // indicates the V AC quantizer relative to base_q_idx. + // Defined as DeltaQVAc in AV1 specification + unsigned char qm_y; // specifies the level in the quantizer matrix that should + // be used for luma plane decoding + unsigned char qm_u; // specifies the level in the quantizer matrix that should + // be used for chroma U plane decoding + unsigned char qm_v; // specifies the level in the quantizer matrix that should + // be used for chroma V plane decoding + + // segmentation - refer to section 6.8.13 of the AV1 specification + // Version 1.0.0 with Errata 1 + unsigned char segmentation_enabled : 1; // 1 indicates that this frame makes + // use of the segmentation tool + unsigned char + segmentation_update_map : 1; // 1 indicates that the segmentation map are + // updated during the decoding of this frame + unsigned char + segmentation_update_data : 1; // 1 indicates that new parameters are about + // to be specified for each segment + unsigned char + segmentation_temporal_update : 1; // 1 indicates that the updates to the + // segmentation map are coded relative + // to the existing segmentation map + unsigned char reserved3_4bits : 4; // reserved bits; must be set to 0 + short segmentation_feature_data[8][8]; // specifies the feature data for a + // segment feature + unsigned char + segmentation_feature_mask[8]; // indicates that the corresponding feature + // is unused or feature value is coded + + // loopfilter - refer to section 6.8.10 of the AV1 specification Version 1.0.0 + // with Errata 1 + unsigned char loop_filter_level[2]; // contains loop filter strength values + unsigned char loop_filter_level_u; // loop filter strength value of U plane + unsigned char loop_filter_level_v; // loop filter strength value of V plane + unsigned char loop_filter_sharpness; // indicates the sharpness level + char loop_filter_ref_deltas[8]; // contains the adjustment needed for the + // filter level based on the chosen reference + // frame + char loop_filter_mode_deltas[2]; // contains the adjustment needed for the + // filter level based on the chosen mode + unsigned char + loop_filter_delta_enabled : 1; // indicates that the filter level depends + // on the mode and reference frame used to + // predict a block + unsigned char + loop_filter_delta_update : 1; // indicates that additional syntax elements + // are present that specify which mode and + // reference frame deltas are to be updated + unsigned char delta_lf_present : 1; // specifies whether loop filter delta + // values are present in the block level + unsigned char delta_lf_res : 2; // specifies the left shift to apply to the + // decoded loop filter values + unsigned char + delta_lf_multi : 1; // separate loop filter deltas for Hy,Vy,U,V edges + unsigned char reserved4_2bits : 2; // reserved bits; must be set to 0 + + // restoration - refer to section 6.10.15 of the AV1 specification + // Version 1.0.0 with Errata 1 + unsigned char lr_unit_size[3]; // specifies the size of loop restoration + // units: 0: 32, 1: 64, 2: 128, 3: 256 + unsigned char lr_type[3]; // used to compute FrameRestorationType + + // reference frames + unsigned char primary_ref_frame; // specifies which reference frame contains + // the CDF values and other state that should + // be loaded at the start of the frame + unsigned char ref_frame_map[8]; // frames in dpb that can be used as reference + // for current or future frames + + unsigned char temporal_layer_id : 4; // temporal layer id + unsigned char spatial_layer_id : 4; // spatial layer id + + unsigned char reserved5_32bits[4]; // reserved bits; must be set to 0 + + // ref frame list + struct { + unsigned int width; + unsigned int height; + unsigned char index; + unsigned char reserved24Bits[3]; // reserved bits; must be set to 0 + } ref_frame[7]; // frames used as reference frame for current frame. + + // global motion + struct { + unsigned char invalid : 1; + unsigned char wmtype : 2; // defined as GmType in AV1 specification + unsigned char reserved5Bits : 5; // reserved bits; must be set to 0 + char reserved24Bits[3]; // reserved bits; must be set to 0 + int wmmat[6]; // defined as gm_params[] in AV1 specification + } global_motion[7]; // global motion params for reference frames + + // film grain params - refer to section 6.8.20 of the AV1 specification + // Version 1.0.0 with Errata 1 + unsigned short apply_grain : 1; + unsigned short overlap_flag : 1; + unsigned short scaling_shift_minus8 : 2; + unsigned short chroma_scaling_from_luma : 1; + unsigned short ar_coeff_lag : 2; + unsigned short ar_coeff_shift_minus6 : 2; + unsigned short grain_scale_shift : 2; + unsigned short clip_to_restricted_range : 1; + unsigned short reserved6_4bits : 4; // reserved bits; must be set to 0 + unsigned char num_y_points; + unsigned char scaling_points_y[14][2]; + unsigned char num_cb_points; + unsigned char scaling_points_cb[10][2]; + unsigned char num_cr_points; + unsigned char scaling_points_cr[10][2]; + unsigned char reserved7_8bits; // reserved bits; must be set to 0 + unsigned short random_seed; + short ar_coeffs_y[24]; + short ar_coeffs_cb[25]; + short ar_coeffs_cr[25]; + unsigned char cb_mult; + unsigned char cb_luma_mult; + short cb_offset; + unsigned char cr_mult; + unsigned char cr_luma_mult; + short cr_offset; + + int reserved[7]; // reserved bits; must be set to 0 +} CUVIDAV1PICPARAMS; + +/******************************************************************************************/ +//! \struct CUVIDPICPARAMS +//! Picture parameters for decoding +//! This structure is used in cuvidDecodePicture API +//! IN for cuvidDecodePicture +/******************************************************************************************/ +typedef struct _CUVIDPICPARAMS { + int PicWidthInMbs; /**< IN: Coded frame size in macroblocks */ + int FrameHeightInMbs; /**< IN: Coded frame height in macroblocks */ + int CurrPicIdx; /**< IN: Output index of the current picture */ + int field_pic_flag; /**< IN: 0=frame picture, 1=field picture */ + int bottom_field_flag; /**< IN: 0=top field, 1=bottom field (ignored if + field_pic_flag=0) */ + int second_field; /**< IN: Second field of a complementary field pair */ + // Bitstream data + unsigned int + nBitstreamDataLen; /**< IN: Number of bytes in bitstream data buffer */ + const unsigned char* pBitstreamData; /**< IN: Ptr to bitstream data for this + picture (slice-layer) */ + unsigned int nNumSlices; /**< IN: Number of slices in this picture */ + const unsigned int* + pSliceDataOffsets; /**< IN: nNumSlices entries, contains offset of each + slice within the bitstream data buffer */ + int ref_pic_flag; /**< IN: This picture is a reference picture */ + int intra_pic_flag; /**< IN: This picture is entirely intra coded */ + unsigned int Reserved[30]; /**< Reserved for future use */ + + // IN: Codec-specific data + union { + CUVIDMPEG2PICPARAMS mpeg2; /**< Also used for MPEG-1 */ + CUVIDH264PICPARAMS h264; + CUVIDVC1PICPARAMS vc1; + CUVIDMPEG4PICPARAMS mpeg4; + CUVIDJPEGPICPARAMS jpeg; + CUVIDHEVCPICPARAMS hevc; + CUVIDVP8PICPARAMS vp8; + CUVIDVP9PICPARAMS vp9; + CUVIDAV1PICPARAMS av1; + unsigned int CodecReserved[1024]; + } CodecSpecific; +} CUVIDPICPARAMS; + +/******************************************************/ +//! \struct CUVIDPROCPARAMS +//! Picture parameters for postprocessing +//! This structure is used in cuvidMapVideoFrame API +/******************************************************/ +typedef struct _CUVIDPROCPARAMS { + int progressive_frame; /**< IN: Input is progressive (deinterlace_mode will be + ignored) */ + int second_field; /**< IN: Output the second field (ignored if deinterlace + mode is Weave) */ + int top_field_first; /**< IN: Input frame is top field first (1st field is + top, 2nd field is bottom) */ + int unpaired_field; /**< IN: Input only contains one field (2nd field is + invalid) */ + // The fields below are used for raw YUV input + unsigned int reserved_flags; /**< Reserved for future use (set to zero) */ + unsigned int reserved_zero; /**< Reserved (set to zero) */ + unsigned long long + raw_input_dptr; /**< IN: Input CUdeviceptr for raw YUV extensions */ + unsigned int raw_input_pitch; /**< IN: pitch in bytes of raw YUV input (should + be aligned appropriately) */ + unsigned int + raw_input_format; /**< IN: Input YUV format (cudaVideoCodec_enum) */ + unsigned long long + raw_output_dptr; /**< IN: Output CUdeviceptr for raw YUV extensions */ + unsigned int raw_output_pitch; /**< IN: pitch in bytes of raw YUV output + (should be aligned appropriately) */ + unsigned int Reserved1; /**< Reserved for future use (set to zero) */ + CUstream output_stream; /**< IN: stream object used by cuvidMapVideoFrame */ + unsigned int Reserved[46]; /**< Reserved for future use (set to zero) */ + unsigned long long* + histogram_dptr; /**< OUT: Output CUdeviceptr for histogram extensions */ + void* Reserved2[1]; /**< Reserved for future use (set to zero) */ +} CUVIDPROCPARAMS; + +/*********************************************************************************************************/ +//! \struct CUVIDGETDECODESTATUS +//! Struct for reporting decode status. +//! This structure is used in cuvidGetDecodeStatus API. +/*********************************************************************************************************/ +typedef struct _CUVIDGETDECODESTATUS { + cuvidDecodeStatus decodeStatus; + unsigned int reserved[31]; + void* pReserved[8]; +} CUVIDGETDECODESTATUS; + +/****************************************************/ +//! \struct CUVIDRECONFIGUREDECODERINFO +//! Struct for decoder reset +//! This structure is used in cuvidReconfigureDecoder() API +/****************************************************/ +typedef struct _CUVIDRECONFIGUREDECODERINFO { + unsigned int ulWidth; /**< IN: Coded sequence width in pixels, MUST be < = + ulMaxWidth defined at CUVIDDECODECREATEINFO */ + unsigned int ulHeight; /**< IN: Coded sequence height in pixels, MUST be < = + ulMaxHeight defined at CUVIDDECODECREATEINFO */ + unsigned int ulTargetWidth; /**< IN: Post processed output width */ + unsigned int ulTargetHeight; /**< IN: Post Processed output height */ + unsigned int ulNumDecodeSurfaces; /**< IN: Maximum number of internal decode + surfaces */ + unsigned int reserved1[12]; /**< Reserved for future use. Set to Zero */ + + /** + * IN: Area of frame to be displayed. Use-case : Source Cropping + */ + struct { + short left; + short top; + short right; + short bottom; + } display_area; + + /** + * IN: Target Rectangle in the OutputFrame. Use-case : Aspect ratio Conversion + */ + struct { + short left; + short top; + short right; + short bottom; + } target_rect; + + unsigned int reserved2[11]; /**< Reserved for future use. Set to Zero */ +} CUVIDRECONFIGUREDECODERINFO; + +/***********************************************************************************************************/ +//! VIDEO_DECODER +//! +//! In order to minimize decode latencies, there should be always at least 2 +//! pictures in the decode queue at any time, in order to make sure that all +//! decode engines are always busy. +//! +//! Overall data flow: +//! - cuvidGetDecoderCaps(...) +//! - cuvidCreateDecoder(...) +//! - For each picture: +//! + cuvidDecodePicture(N) +//! + cuvidMapVideoFrame(N-4) +//! + do some processing in cuda +//! + cuvidUnmapVideoFrame(N-4) +//! + cuvidDecodePicture(N+1) +//! + cuvidMapVideoFrame(N-3) +//! + ... +//! - cuvidDestroyDecoder(...) +//! +//! NOTE: +//! - When the cuda context is created from a D3D device, the D3D device must +//! also be created +//! with the D3DCREATE_MULTITHREADED flag. +//! - There is a limit to how many pictures can be mapped simultaneously +//! (ulNumOutputSurfaces) +//! - cuvidDecodePicture may block the calling thread if there are too many +//! pictures pending +//! in the decode queue +/***********************************************************************************************************/ + +/**********************************************************************************************************************/ +//! \fn CUresult CUDAAPI cuvidGetDecoderCaps(CUVIDDECODECAPS *pdc) +//! Queries decode capabilities of NVDEC-HW based on CodecType, ChromaFormat and +//! BitDepthMinus8 parameters. +//! 1. Application fills IN parameters CodecType, ChromaFormat and +//! BitDepthMinus8 of CUVIDDECODECAPS structure +//! 2. On calling cuvidGetDecoderCaps, driver fills OUT parameters if the IN +//! parameters are supported +//! If IN parameters passed to the driver are not supported by NVDEC-HW, then +//! all OUT params are set to 0. +//! E.g. on Geforce GTX 960: +//! App fills - eCodecType = cudaVideoCodec_H264; eChromaFormat = +//! cudaVideoChromaFormat_420; nBitDepthMinus8 = 0; Given IN parameters are +//! supported, hence driver fills: bIsSupported = 1; nMinWidth = 48; +//! nMinHeight = 16; nMaxWidth = 4096; nMaxHeight = 4096; nMaxMBCount = +//! 65536; +//! CodedWidth*CodedHeight/256 must be less than or equal to nMaxMBCount +/**********************************************************************************************************************/ +extern CUresult CUDAAPI cuvidGetDecoderCaps(CUVIDDECODECAPS* pdc); + +/*****************************************************************************************************/ +//! \fn CUresult CUDAAPI cuvidCreateDecoder(CUvideodecoder *phDecoder, +//! CUVIDDECODECREATEINFO *pdci) Create the decoder object based on pdci. A +//! handle to the created decoder is returned +/*****************************************************************************************************/ +extern CUresult CUDAAPI +cuvidCreateDecoder(CUvideodecoder* phDecoder, CUVIDDECODECREATEINFO* pdci); + +/*****************************************************************************************************/ +//! \fn CUresult CUDAAPI cuvidDestroyDecoder(CUvideodecoder hDecoder) +//! Destroy the decoder object +/*****************************************************************************************************/ +extern CUresult CUDAAPI cuvidDestroyDecoder(CUvideodecoder hDecoder); + +/*****************************************************************************************************/ +//! \fn CUresult CUDAAPI cuvidDecodePicture(CUvideodecoder hDecoder, +//! CUVIDPICPARAMS *pPicParams) Decode a single picture (field or frame) Kicks +//! off HW decoding +/*****************************************************************************************************/ +extern CUresult CUDAAPI +cuvidDecodePicture(CUvideodecoder hDecoder, CUVIDPICPARAMS* pPicParams); + +/************************************************************************************************************/ +//! \fn CUresult CUDAAPI cuvidGetDecodeStatus(CUvideodecoder hDecoder, int +//! nPicIdx); Get the decode status for frame corresponding to nPicIdx API is +//! supported for Maxwell and above generation GPUs. API is currently supported +//! for HEVC, H264 and JPEG codecs. API returns CUDA_ERROR_NOT_SUPPORTED error +//! code for unsupported GPU or codec. +/************************************************************************************************************/ +extern CUresult CUDAAPI cuvidGetDecodeStatus( + CUvideodecoder hDecoder, + int nPicIdx, + CUVIDGETDECODESTATUS* pDecodeStatus); + +/*********************************************************************************************************/ +//! \fn CUresult CUDAAPI cuvidReconfigureDecoder(CUvideodecoder hDecoder, +//! CUVIDRECONFIGUREDECODERINFO *pDecReconfigParams) Used to reuse single +//! decoder for multiple clips. Currently supports resolution change, resize +//! params, display area params, target area params change for same codec. Must +//! be called during CUVIDPARSERPARAMS::pfnSequenceCallback +/*********************************************************************************************************/ +extern CUresult CUDAAPI cuvidReconfigureDecoder( + CUvideodecoder hDecoder, + CUVIDRECONFIGUREDECODERINFO* pDecReconfigParams); + +#if !defined(__CUVID_DEVPTR64) || defined(__CUVID_INTERNAL) +/************************************************************************************************************************/ +//! \fn CUresult CUDAAPI cuvidMapVideoFrame(CUvideodecoder hDecoder, int +//! nPicIdx, unsigned int *pDevPtr, +//! unsigned int *pPitch, +//! CUVIDPROCPARAMS *pVPP); +//! Post-process and map video frame corresponding to nPicIdx for use in cuda. +//! Returns cuda device pointer and associated pitch of the video frame +/************************************************************************************************************************/ +extern CUresult CUDAAPI cuvidMapVideoFrame( + CUvideodecoder hDecoder, + int nPicIdx, + unsigned int* pDevPtr, + unsigned int* pPitch, + CUVIDPROCPARAMS* pVPP); + +/*****************************************************************************************************/ +//! \fn CUresult CUDAAPI cuvidUnmapVideoFrame(CUvideodecoder hDecoder, unsigned +//! int DevPtr) Unmap a previously mapped video frame +/*****************************************************************************************************/ +extern CUresult CUDAAPI +cuvidUnmapVideoFrame(CUvideodecoder hDecoder, unsigned int DevPtr); +#endif + +/****************************************************************************************************************************/ +//! \fn CUresult CUDAAPI cuvidMapVideoFrame64(CUvideodecoder hDecoder, int +//! nPicIdx, unsigned long long *pDevPtr, +//! unsigned int * pPitch, +//! CUVIDPROCPARAMS *pVPP); +//! Post-process and map video frame corresponding to nPicIdx for use in cuda. +//! Returns cuda device pointer and associated pitch of the video frame +/****************************************************************************************************************************/ +extern CUresult CUDAAPI cuvidMapVideoFrame64( + CUvideodecoder hDecoder, + int nPicIdx, + unsigned long long* pDevPtr, + unsigned int* pPitch, + CUVIDPROCPARAMS* pVPP); + +/**************************************************************************************************/ +//! \fn CUresult CUDAAPI cuvidUnmapVideoFrame64(CUvideodecoder hDecoder, +//! unsigned long long DevPtr); Unmap a previously mapped video frame +/**************************************************************************************************/ +extern CUresult CUDAAPI +cuvidUnmapVideoFrame64(CUvideodecoder hDecoder, unsigned long long DevPtr); + +#if defined(__CUVID_DEVPTR64) && !defined(__CUVID_INTERNAL) +#define cuvidMapVideoFrame cuvidMapVideoFrame64 +#define cuvidUnmapVideoFrame cuvidUnmapVideoFrame64 +#endif + +/********************************************************************************************************************/ +//! +//! Context-locking: to facilitate multi-threaded implementations, the following +//! 4 functions provide a simple mutex-style host synchronization. If a non-NULL +//! context is specified in CUVIDDECODECREATEINFO, the codec library will +//! acquire the mutex associated with the given context before making any cuda +//! calls. A multi-threaded application could create a lock associated with a +//! context handle so that multiple threads can safely share the same cuda +//! context: +//! - use cuCtxPopCurrent immediately after context creation in order to create +//! a 'floating' context +//! that can be passed to cuvidCtxLockCreate. +//! - When using a floating context, all cuda calls should only be made within +//! a cuvidCtxLock/cuvidCtxUnlock section. +//! +//! NOTE: This is a safer alternative to cuCtxPushCurrent and cuCtxPopCurrent, +//! and is not related to video decoder in any way (implemented as a critical +//! section associated with cuCtx{Push|Pop}Current calls). +/********************************************************************************************************************/ + +/********************************************************************************************************************/ +//! \fn CUresult CUDAAPI cuvidCtxLockCreate(CUvideoctxlock *pLock, CUcontext +//! ctx) This API is used to create CtxLock object +/********************************************************************************************************************/ +extern CUresult CUDAAPI +cuvidCtxLockCreate(CUvideoctxlock* pLock, CUcontext ctx); + +/********************************************************************************************************************/ +//! \fn CUresult CUDAAPI cuvidCtxLockDestroy(CUvideoctxlock lck) +//! This API is used to free CtxLock object +/********************************************************************************************************************/ +extern CUresult CUDAAPI cuvidCtxLockDestroy(CUvideoctxlock lck); + +/********************************************************************************************************************/ +//! \fn CUresult CUDAAPI cuvidCtxLock(CUvideoctxlock lck, unsigned int +//! reserved_flags) This API is used to acquire ctxlock +/********************************************************************************************************************/ +extern CUresult CUDAAPI +cuvidCtxLock(CUvideoctxlock lck, unsigned int reserved_flags); + +/********************************************************************************************************************/ +//! \fn CUresult CUDAAPI cuvidCtxUnlock(CUvideoctxlock lck, unsigned int +//! reserved_flags) This API is used to release ctxlock +/********************************************************************************************************************/ +extern CUresult CUDAAPI +cuvidCtxUnlock(CUvideoctxlock lck, unsigned int reserved_flags); + +/**********************************************************************************************/ + +#if defined(__cplusplus) +} + +// Auto-lock helper for C++ applications +class CCtxAutoLock { + private: + CUvideoctxlock m_ctx; + + public: + CCtxAutoLock(CUvideoctxlock ctx) : m_ctx(ctx) { + cuvidCtxLock(m_ctx, 0); + } + + ~CCtxAutoLock() { + cuvidCtxUnlock(m_ctx, 0); + } +}; +#endif /* __cplusplus */ + +#endif // __CUDA_VIDEO_H__ diff --git a/src/torchcodec/_core/nvcuvid_include/nvcuvid.h b/src/torchcodec/_core/nvcuvid_include/nvcuvid.h new file mode 100644 index 000000000..f0d9446d7 --- /dev/null +++ b/src/torchcodec/_core/nvcuvid_include/nvcuvid.h @@ -0,0 +1,610 @@ +/* + * This copyright notice applies to this header file only: + * + * Copyright (c) 2010-2024 NVIDIA Corporation + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the software, and to permit persons to whom the + * software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +/********************************************************************************************************************/ +//! \file nvcuvid.h +//! NVDECODE API provides video decoding interface to NVIDIA GPU devices. +//! \date 2015-2024 +//! This file contains the interface constants, structure definitions and +//! function prototypes. +/********************************************************************************************************************/ + +#if !defined(__NVCUVID_H__) +#define __NVCUVID_H__ + +#include "cuviddec.h" + +#if defined(__cplusplus) +extern "C" { +#endif /* __cplusplus */ + +#define MAX_CLOCK_TS 3 + +/***********************************************/ +//! +//! High-level helper APIs for video sources +//! +/***********************************************/ + +typedef void* CUvideosource; +typedef void* CUvideoparser; +typedef long long CUvideotimestamp; + +/************************************************************************/ +//! \enum cudaVideoState +//! Video source state enums +//! Used in cuvidSetVideoSourceState and cuvidGetVideoSourceState APIs +/************************************************************************/ +typedef enum { + cudaVideoState_Error = -1, /**< Error state (invalid source) */ + cudaVideoState_Stopped = + 0, /**< Source is stopped (or reached end-of-stream) */ + cudaVideoState_Started = 1 /**< Source is running and delivering data */ +} cudaVideoState; + +/************************************************************************/ +//! \enum cudaAudioCodec +//! Audio compression enums +//! Used in CUAUDIOFORMAT structure +/************************************************************************/ +typedef enum { + cudaAudioCodec_MPEG1 = 0, /**< MPEG-1 Audio */ + cudaAudioCodec_MPEG2, /**< MPEG-2 Audio */ + cudaAudioCodec_MP3, /**< MPEG-1 Layer III Audio */ + cudaAudioCodec_AC3, /**< Dolby Digital (AC3) Audio */ + cudaAudioCodec_LPCM, /**< PCM Audio */ + cudaAudioCodec_AAC, /**< AAC Audio */ +} cudaAudioCodec; + +/************************************************************************/ +//! \ingroup STRUCTS +//! \struct TIMECODESET +//! Used to store Time code set extracted from H264 and HEVC codecs +/************************************************************************/ +typedef struct _TIMECODESET { + unsigned int time_offset_value; + unsigned short n_frames; + unsigned char clock_timestamp_flag; + unsigned char units_field_based_flag; + unsigned char counting_type; + unsigned char full_timestamp_flag; + unsigned char discontinuity_flag; + unsigned char cnt_dropped_flag; + unsigned char seconds_value; + unsigned char minutes_value; + unsigned char hours_value; + unsigned char seconds_flag; + unsigned char minutes_flag; + unsigned char hours_flag; + unsigned char time_offset_length; + unsigned char reserved; +} TIMECODESET; + +/************************************************************************/ +//! \ingroup STRUCTS +//! \struct TIMECODE +//! Used to extract Time code in H264 and HEVC codecs +/************************************************************************/ +typedef struct _TIMECODE { + TIMECODESET time_code_set[MAX_CLOCK_TS]; + unsigned char num_clock_ts; +} TIMECODE; + +/**********************************************************************************/ +//! \ingroup STRUCTS +//! \struct SEIMASTERINGDISPLAYINFO +//! Used to extract mastering display color volume SEI in H264 and HEVC codecs +/**********************************************************************************/ +typedef struct _SEIMASTERINGDISPLAYINFO { + unsigned short display_primaries_x[3]; + unsigned short display_primaries_y[3]; + unsigned short white_point_x; + unsigned short white_point_y; + unsigned int max_display_mastering_luminance; + unsigned int min_display_mastering_luminance; +} SEIMASTERINGDISPLAYINFO; + +/**********************************************************************************/ +//! \ingroup STRUCTS +//! \struct SEICONTENTLIGHTLEVELINFO +//! Used to extract content light level info SEI in H264 and HEVC codecs +/**********************************************************************************/ +typedef struct _SEICONTENTLIGHTLEVELINFO { + unsigned short max_content_light_level; + unsigned short max_pic_average_light_level; + unsigned int reserved; +} SEICONTENTLIGHTLEVELINFO; + +/**********************************************************************************/ +//! \ingroup STRUCTS +//! \struct TIMECODEMPEG2 +//! Used to extract Time code in MPEG2 codec +/**********************************************************************************/ +typedef struct _TIMECODEMPEG2 { + unsigned char drop_frame_flag; + unsigned char time_code_hours; + unsigned char time_code_minutes; + unsigned char marker_bit; + unsigned char time_code_seconds; + unsigned char time_code_pictures; +} TIMECODEMPEG2; + +/**********************************************************************************/ +//! \ingroup STRUCTS +//! \struct SEIALTERNATIVETRANSFERCHARACTERISTICS +//! Used to extract alternative transfer characteristics SEI in H264 and HEVC +//! codecs +/**********************************************************************************/ +typedef struct _SEIALTERNATIVETRANSFERCHARACTERISTICS { + unsigned char preferred_transfer_characteristics; +} SEIALTERNATIVETRANSFERCHARACTERISTICS; + +/**********************************************************************************/ +//! \ingroup STRUCTS +//! \struct CUSEIMESSAGE; +//! Used in CUVIDSEIMESSAGEINFO structure +/**********************************************************************************/ +typedef struct _CUSEIMESSAGE { + unsigned char sei_message_type; /**< OUT: SEI Message Type */ + unsigned char reserved[3]; + unsigned int sei_message_size; /**< OUT: SEI Message Size */ +} CUSEIMESSAGE; + +/************************************************************************************************/ +//! \ingroup STRUCTS +//! \struct CUVIDEOFORMAT +//! Video format +//! Used in cuvidGetSourceVideoFormat API +/************************************************************************************************/ +typedef struct { + cudaVideoCodec codec; /**< OUT: Compression format */ + + /** + * OUT: frame rate = numerator / denominator (for example: 30000/1001) + */ + struct { + /**< OUT: frame rate numerator (0 = unspecified or variable frame rate) */ + unsigned int numerator; + /**< OUT: frame rate denominator (0 = unspecified or variable frame rate) */ + unsigned int denominator; + } frame_rate; + + unsigned char progressive_sequence; /**< OUT: 0=interlaced, 1=progressive */ + unsigned char bit_depth_luma_minus8; /**< OUT: high bit depth luma. E.g, 2 for + 10-bitdepth, 4 for 12-bitdepth */ + unsigned char + bit_depth_chroma_minus8; /**< OUT: high bit depth chroma. E.g, 2 for + 10-bitdepth, 4 for 12-bitdepth */ + unsigned char + min_num_decode_surfaces; /**< OUT: Minimum number of decode surfaces to be + allocated for correct decoding. The client can + send this value in ulNumDecodeSurfaces (in + CUVIDDECODECREATEINFO structure). This + guarantees correct functionality and optimal + video memory usage but not necessarily the + best performance, which depends on the design + of the overall application. The optimal number + of decode surfaces (in terms of performance + and memory utilization) should be decided by + experimentation for each application, but it + cannot go below + min_num_decode_surfaces. If this value is used + for ulNumDecodeSurfaces then it must be + returned to parser during sequence + callback. */ + unsigned int coded_width; /**< OUT: coded frame width in pixels */ + unsigned int coded_height; /**< OUT: coded frame height in pixels */ + + /** + * area of the frame that should be displayed + * typical example: + * coded_width = 1920, coded_height = 1088 + * display_area = { 0,0,1920,1080 } + */ + struct { + int left; /**< OUT: left position of display rect */ + int top; /**< OUT: top position of display rect */ + int right; /**< OUT: right position of display rect */ + int bottom; /**< OUT: bottom position of display rect */ + } display_area; + + cudaVideoChromaFormat chroma_format; /**< OUT: Chroma format */ + unsigned int bitrate; /**< OUT: video bitrate (bps, 0=unknown) */ + + /** + * OUT: Display Aspect Ratio = x:y (4:3, 16:9, etc) + */ + struct { + int x; + int y; + } display_aspect_ratio; + + /** + * Video Signal Description + * Refer section E.2.1 (VUI parameters semantics) of H264 spec file + */ + struct { + unsigned char video_format : 3; /**< OUT: 0-Component, 1-PAL, 2-NTSC, + 3-SECAM, 4-MAC, 5-Unspecified */ + unsigned char video_full_range_flag : 1; /**< OUT: indicates the black level + and luma and chroma range */ + unsigned char reserved_zero_bits : 4; /**< Reserved bits */ + unsigned char color_primaries; /**< OUT: chromaticity coordinates of source + primaries */ + unsigned char + transfer_characteristics; /**< OUT: opto-electronic transfer + characteristic of the source picture */ + unsigned char matrix_coefficients; /**< OUT: used in deriving luma and + chroma signals from RGB primaries */ + } video_signal_description; + + unsigned int seqhdr_data_length; /**< OUT: Additional bytes following + (CUVIDEOFORMATEX) */ +} CUVIDEOFORMAT; + +/****************************************************************/ +//! \ingroup STRUCTS +//! \struct CUVIDOPERATINGPOINTINFO +//! Operating point information of scalable bitstream +/****************************************************************/ +typedef struct { + cudaVideoCodec codec; + + union { + struct { + unsigned char operating_points_cnt; + unsigned char reserved24_bits[3]; + unsigned short operating_points_idc[32]; + } av1; + + unsigned char CodecReserved[1024]; + }; +} CUVIDOPERATINGPOINTINFO; + +/**********************************************************************************/ +//! \ingroup STRUCTS +//! \struct CUVIDSEIMESSAGEINFO +//! Used in cuvidParseVideoData API with PFNVIDSEIMSGCALLBACK pfnGetSEIMsg +/**********************************************************************************/ +typedef struct _CUVIDSEIMESSAGEINFO { + void* pSEIData; /**< OUT: SEI Message Data */ + CUSEIMESSAGE* pSEIMessage; /**< OUT: SEI Message Info */ + unsigned int sei_message_count; /**< OUT: SEI Message Count */ + unsigned int picIdx; /**< OUT: SEI Message Pic Index */ +} CUVIDSEIMESSAGEINFO; + +/****************************************************************/ +//! \ingroup STRUCTS +//! \struct CUVIDAV1SEQHDR +//! AV1 specific sequence header information +/****************************************************************/ +typedef struct { + unsigned int max_width; + unsigned int max_height; + unsigned char reserved[1016]; +} CUVIDAV1SEQHDR; + +/****************************************************************/ +//! \ingroup STRUCTS +//! \struct CUVIDEOFORMATEX +//! Video format including raw sequence header information +//! Used in cuvidGetSourceVideoFormat API +/****************************************************************/ +typedef struct { + CUVIDEOFORMAT format; /**< OUT: CUVIDEOFORMAT structure */ + + union { + CUVIDAV1SEQHDR av1; + unsigned char raw_seqhdr_data[1024]; /**< OUT: Sequence header data */ + }; +} CUVIDEOFORMATEX; + +/****************************************************************/ +//! \ingroup STRUCTS +//! \struct CUAUDIOFORMAT +//! Audio formats +//! Used in cuvidGetSourceAudioFormat API +/****************************************************************/ +typedef struct { + cudaAudioCodec codec; /**< OUT: Compression format */ + unsigned int channels; /**< OUT: number of audio channels */ + unsigned int samplespersec; /**< OUT: sampling frequency */ + unsigned int bitrate; /**< OUT: For uncompressed, can also be used to + determine bits per sample */ + unsigned int reserved1; /**< Reserved for future use */ + unsigned int reserved2; /**< Reserved for future use */ +} CUAUDIOFORMAT; + +/***************************************************************/ +//! \enum CUvideopacketflags +//! Data packet flags +//! Used in CUVIDSOURCEDATAPACKET structure +/***************************************************************/ +typedef enum { + CUVID_PKT_ENDOFSTREAM = + 0x01, /**< Set when this is the last packet for this stream */ + CUVID_PKT_TIMESTAMP = 0x02, /**< Timestamp is valid */ + CUVID_PKT_DISCONTINUITY = + 0x04, /**< Set when a discontinuity has to be signalled */ + CUVID_PKT_ENDOFPICTURE = + 0x08, /**< Set when the packet contains exactly one frame or one field */ + CUVID_PKT_NOTIFY_EOS = + 0x10, /**< If this flag is set along with CUVID_PKT_ENDOFSTREAM, an + additional (dummy) display callback will be invoked with null + value of CUVIDPARSERDISPINFO which should be interpreted as end + of the stream. */ +} CUvideopacketflags; + +/*****************************************************************************/ +//! \ingroup STRUCTS +//! \struct CUVIDSOURCEDATAPACKET +//! Data Packet +//! Used in cuvidParseVideoData API +//! IN for cuvidParseVideoData +/*****************************************************************************/ +typedef struct _CUVIDSOURCEDATAPACKET { + unsigned long flags; /**< IN: Combination of CUVID_PKT_XXX flags */ + unsigned long payload_size; /**< IN: number of bytes in the payload (may be + zero if EOS flag is set) */ + const unsigned char* payload; /**< IN: Pointer to packet payload data (may be + NULL if EOS flag is set) */ + CUvideotimestamp + timestamp; /**< IN: Presentation time stamp (10MHz clock), only valid if + CUVID_PKT_TIMESTAMP flag is set */ +} CUVIDSOURCEDATAPACKET; + +// Callback for packet delivery +typedef int(CUDAAPI* PFNVIDSOURCECALLBACK)(void*, CUVIDSOURCEDATAPACKET*); + +/**************************************************************************************************************************/ +//! \ingroup STRUCTS +//! \struct CUVIDSOURCEPARAMS +//! Describes parameters needed in cuvidCreateVideoSource API +//! NVDECODE API is intended for HW accelerated video decoding so CUvideosource +//! doesn't have audio demuxer for all supported containers. It's recommended to +//! clients to use their own or third party demuxer if audio support is needed. +/**************************************************************************************************************************/ +typedef struct _CUVIDSOURCEPARAMS { + unsigned int + ulClockRate; /**< IN: Time stamp units in Hz (0=default=10000000Hz) */ + unsigned int bAnnexb : 1; /**< IN: AV1 annexB stream */ + unsigned int uReserved : 31; /**< Reserved for future use - set to zero */ + unsigned int uReserved1[6]; /**< Reserved for future use - set to zero */ + void* pUserData; /**< IN: User private data passed in to the data handlers */ + PFNVIDSOURCECALLBACK + pfnVideoDataHandler; /**< IN: Called to deliver video packets */ + PFNVIDSOURCECALLBACK + pfnAudioDataHandler; /**< IN: Called to deliver audio packets. */ + void* pvReserved2[8]; /**< Reserved for future use - set to NULL */ +} CUVIDSOURCEPARAMS; + +/**********************************************/ +//! \ingroup ENUMS +//! \enum CUvideosourceformat_flags +//! CUvideosourceformat_flags +//! Used in cuvidGetSourceVideoFormat API +/**********************************************/ +typedef enum { + CUVID_FMT_EXTFORMATINFO = + 0x100 /**< Return extended format structure (CUVIDEOFORMATEX) */ +} CUvideosourceformat_flags; + +#if !defined(__APPLE__) +/***************************************************************************************************************************/ +//! \ingroup FUNCTS +//! \fn CUresult CUDAAPI cuvidCreateVideoSource(CUvideosource *pObj, const char +//! *pszFileName, CUVIDSOURCEPARAMS *pParams) Create CUvideosource object. +//! CUvideosource spawns demultiplexer thread that provides two callbacks: +//! pfnVideoDataHandler() and pfnAudioDataHandler() +//! NVDECODE API is intended for HW accelerated video decoding so CUvideosource +//! doesn't have audio demuxer for all supported containers. It's recommended to +//! clients to use their own or third party demuxer if audio support is needed. +/***************************************************************************************************************************/ +CUresult CUDAAPI cuvidCreateVideoSource( + CUvideosource* pObj, + const char* pszFileName, + CUVIDSOURCEPARAMS* pParams); + +/***************************************************************************************************************************/ +//! \ingroup FUNCTS +//! \fn CUresult CUDAAPI cuvidCreateVideoSourceW(CUvideosource *pObj, const +//! wchar_t *pwszFileName, CUVIDSOURCEPARAMS *pParams) Create video source +/***************************************************************************************************************************/ +CUresult CUDAAPI cuvidCreateVideoSourceW( + CUvideosource* pObj, + const wchar_t* pwszFileName, + CUVIDSOURCEPARAMS* pParams); + +/********************************************************************/ +//! \ingroup FUNCTS +//! \fn CUresult CUDAAPI cuvidDestroyVideoSource(CUvideosource obj) +//! Destroy video source +/********************************************************************/ +CUresult CUDAAPI cuvidDestroyVideoSource(CUvideosource obj); + +/******************************************************************************************/ +//! \ingroup FUNCTS +//! \fn CUresult CUDAAPI cuvidSetVideoSourceState(CUvideosource obj, +//! cudaVideoState state) Set video source state to: cudaVideoState_Started - to +//! signal the source to run and deliver data cudaVideoState_Stopped - to stop +//! the source from delivering the data cudaVideoState_Error - invalid source +/******************************************************************************************/ +CUresult CUDAAPI +cuvidSetVideoSourceState(CUvideosource obj, cudaVideoState state); + +/******************************************************************************************/ +//! \ingroup FUNCTS +//! \fn cudaVideoState CUDAAPI cuvidGetVideoSourceState(CUvideosource obj) +//! Get video source state +//! Returns: +//! cudaVideoState_Started - if Source is running and delivering data +//! cudaVideoState_Stopped - if Source is stopped or reached end-of-stream +//! cudaVideoState_Error - if Source is in error state +/******************************************************************************************/ +cudaVideoState CUDAAPI cuvidGetVideoSourceState(CUvideosource obj); + +/******************************************************************************************************************/ +//! \ingroup FUNCTS +//! \fn CUresult CUDAAPI cuvidGetSourceVideoFormat(CUvideosource obj, +//! CUVIDEOFORMAT *pvidfmt, unsigned int flags) Gets video source format in +//! pvidfmt, flags is set to combination of CUvideosourceformat_flags as per +//! requirement +/******************************************************************************************************************/ +CUresult CUDAAPI cuvidGetSourceVideoFormat( + CUvideosource obj, + CUVIDEOFORMAT* pvidfmt, + unsigned int flags); + +/**************************************************************************************************************************/ +//! \ingroup FUNCTS +//! \fn CUresult CUDAAPI cuvidGetSourceAudioFormat(CUvideosource obj, +//! CUAUDIOFORMAT *paudfmt, unsigned int flags) Get audio source format NVDECODE +//! API is intended for HW accelerated video decoding so CUvideosource doesn't +//! have audio demuxer for all supported containers. It's recommended to clients +//! to use their own or third party demuxer if audio support is needed. +/**************************************************************************************************************************/ +CUresult CUDAAPI cuvidGetSourceAudioFormat( + CUvideosource obj, + CUAUDIOFORMAT* paudfmt, + unsigned int flags); + +#endif +/**********************************************************************************/ +//! \ingroup STRUCTS +//! \struct CUVIDPARSERDISPINFO +//! Used in cuvidParseVideoData API with PFNVIDDISPLAYCALLBACK pfnDisplayPicture +/**********************************************************************************/ +typedef struct _CUVIDPARSERDISPINFO { + int picture_index; /**< OUT: Index of the current picture */ + int progressive_frame; /**< OUT: 1 if progressive frame; 0 otherwise */ + int top_field_first; /**< OUT: 1 if top field is displayed first; 0 otherwise + */ + int repeat_first_field; /**< OUT: Number of additional fields (1=ivtc, 2=frame + doubling, 4=frame tripling, -1=unpaired field) */ + CUvideotimestamp timestamp; /**< OUT: Presentation time stamp */ +} CUVIDPARSERDISPINFO; + +/***********************************************************************************************************************/ +//! Parser callbacks +//! The parser will call these synchronously from within cuvidParseVideoData(), +//! whenever there is sequence change or a picture is ready to be decoded and/or +//! displayed. First argument in functions is "void *pUserData" member of +//! structure CUVIDSOURCEPARAMS Return values from these callbacks are +//! interpreted as below. If the callbacks return failure, it will be propagated +//! by cuvidParseVideoData() to the application. Parser picks default operating +//! point as 0 and outputAllLayers flag as 0 if PFNVIDOPPOINTCALLBACK is not set +//! or return value is -1 or invalid operating point. PFNVIDSEQUENCECALLBACK : +//! 0: fail, 1: succeeded, > 1: override dpb size of parser (set by +//! CUVIDPARSERPARAMS::ulMaxNumDecodeSurfaces while creating parser) +//! PFNVIDDECODECALLBACK : 0: fail, >=1: succeeded +//! PFNVIDDISPLAYCALLBACK : 0: fail, >=1: succeeded +//! PFNVIDOPPOINTCALLBACK : <0: fail, >=0: succeeded (bit 0-9: OperatingPoint, +//! bit 10-10: outputAllLayers, bit 11-30: reserved) PFNVIDSEIMSGCALLBACK : 0: +//! fail, >=1: succeeded +/***********************************************************************************************************************/ +typedef int(CUDAAPI* PFNVIDSEQUENCECALLBACK)(void*, CUVIDEOFORMAT*); +typedef int(CUDAAPI* PFNVIDDECODECALLBACK)(void*, CUVIDPICPARAMS*); +typedef int(CUDAAPI* PFNVIDDISPLAYCALLBACK)(void*, CUVIDPARSERDISPINFO*); +typedef int(CUDAAPI* PFNVIDOPPOINTCALLBACK)(void*, CUVIDOPERATINGPOINTINFO*); +typedef int(CUDAAPI* PFNVIDSEIMSGCALLBACK)(void*, CUVIDSEIMESSAGEINFO*); + +/**************************************/ +//! \ingroup STRUCTS +//! \struct CUVIDPARSERPARAMS +//! Used in cuvidCreateVideoParser API +/**************************************/ +typedef struct _CUVIDPARSERPARAMS { + cudaVideoCodec CodecType; /**< IN: cudaVideoCodec_XXX */ + unsigned int ulMaxNumDecodeSurfaces; /**< IN: Max # of decode surfaces (parser + will cycle through these) */ + unsigned int + ulClockRate; /**< IN: Timestamp units in Hz (0=default=10000000Hz) */ + unsigned int ulErrorThreshold; /**< IN: % Error threshold (0-100) for calling + pfnDecodePicture (100=always IN: call + pfnDecodePicture even if picture bitstream + is fully corrupted) */ + unsigned int ulMaxDisplayDelay; /**< IN: Max display queue delay (improves + pipelining of decode with display) 0=no + delay (recommended values: 2..4) */ + unsigned int bAnnexb : 1; /**< IN: AV1 annexB stream */ + unsigned int uReserved : 31; /**< Reserved for future use - set to zero */ + unsigned int uReserved1[4]; /**< IN: Reserved for future use - set to 0 */ + void* pUserData; /**< IN: User data for callbacks */ + PFNVIDSEQUENCECALLBACK + pfnSequenceCallback; /**< IN: Called before decoding frames and/or + whenever there is a fmt change */ + PFNVIDDECODECALLBACK pfnDecodePicture; /**< IN: Called when a picture is ready + to be decoded (decode order) */ + PFNVIDDISPLAYCALLBACK + pfnDisplayPicture; /**< IN: Called whenever a picture is ready to be + displayed (display order) */ + PFNVIDOPPOINTCALLBACK + pfnGetOperatingPoint; /**< IN: Called from AV1 sequence header to get + operating point of a AV1 scalable bitstream */ + PFNVIDSEIMSGCALLBACK pfnGetSEIMsg; /**< IN: Called when all SEI messages are + parsed for particular frame */ + void* pvReserved2[5]; /**< Reserved for future use - set to NULL */ + CUVIDEOFORMATEX* pExtVideoInfo; /**< IN: [Optional] sequence header data from + system layer */ +} CUVIDPARSERPARAMS; + +/************************************************************************************************/ +//! \ingroup FUNCTS +//! \fn CUresult CUDAAPI cuvidCreateVideoParser(CUvideoparser *pObj, +//! CUVIDPARSERPARAMS *pParams) Create video parser object and initialize +/************************************************************************************************/ +CUresult CUDAAPI +cuvidCreateVideoParser(CUvideoparser* pObj, CUVIDPARSERPARAMS* pParams); + +/************************************************************************************************/ +//! \ingroup FUNCTS +//! \fn CUresult CUDAAPI cuvidParseVideoData(CUvideoparser obj, +//! CUVIDSOURCEDATAPACKET *pPacket) Parse the video data from source data packet +//! in pPacket Extracts parameter sets like SPS, PPS, bitstream etc. from +//! pPacket and calls back pfnDecodePicture with CUVIDPICPARAMS data for kicking +//! of HW decoding calls back pfnSequenceCallback with CUVIDEOFORMAT data for +//! initial sequence header or when the decoder encounters a video format change +//! calls back pfnDisplayPicture with CUVIDPARSERDISPINFO data to display a +//! video frame +/************************************************************************************************/ +CUresult CUDAAPI +cuvidParseVideoData(CUvideoparser obj, CUVIDSOURCEDATAPACKET* pPacket); + +/************************************************************************************************/ +//! \ingroup FUNCTS +//! \fn CUresult CUDAAPI cuvidDestroyVideoParser(CUvideoparser obj) +//! Destroy the video parser +/************************************************************************************************/ +CUresult CUDAAPI cuvidDestroyVideoParser(CUvideoparser obj); + +/**********************************************************************************************/ + +#if defined(__cplusplus) +} +#endif /* __cplusplus */ + +#endif // __NVCUVID_H__ diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index d618b8d9f..40486dcb5 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -276,6 +276,7 @@ def _add_video_stream_abstract( dimension_order: Optional[str] = None, stream_index: Optional[int] = None, device: Optional[str] = None, + device_variant: str = "default", custom_frame_mappings: Optional[ tuple[torch.Tensor, torch.Tensor, torch.Tensor] ] = None, @@ -294,6 +295,7 @@ def add_video_stream_abstract( dimension_order: Optional[str] = None, stream_index: Optional[int] = None, device: Optional[str] = None, + device_variant: str = "default", custom_frame_mappings: Optional[ tuple[torch.Tensor, torch.Tensor, torch.Tensor] ] = None, diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 3bf7a6ac2..150dd056d 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -141,12 +141,31 @@ def __init__( if isinstance(device, torch_device): device = str(device) + # If device looks like "cuda:0:beta", make it "cuda:0" and set + # device_variant to "beta" + # TODONVDEC P2 Consider alternative ways of exposing custom device + # variants, and if we want this new decoder backend to be a "device + # variant" at all. + device_split = device.split(":") + if len(device_split) == 3: + device_variant = device_split[2] + device = ":".join(device_split[0:2]) + else: + device_variant = "default" + + # TODONVDEC P0 Support approximate mode. Not ideal to validate that here + # either, but validating this at a lower level forces to add yet another + # (temprorary) validation API to the device inteface + if device_variant == "beta" and seek_mode != "exact": + raise ValueError("Seek mode must be exact for BETA CUDA interface.") + core.add_video_stream( self._decoder, stream_index=stream_index, dimension_order=dimension_order, num_threads=num_ffmpeg_threads, device=device, + device_variant=device_variant, custom_frame_mappings=custom_frame_mappings_data, ) diff --git a/test/test_decoders.py b/test/test_decoders.py index e68e4fe6e..dee5e60fc 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -43,6 +43,7 @@ SINE_MONO_S32, SINE_MONO_S32_44100, SINE_MONO_S32_8000, + TEST_SRC_2_720P, ) @@ -1399,6 +1400,64 @@ def test_get_frames_at_tensor_indices(self): decoder.get_frames_played_at(torch.tensor([0, 1], dtype=torch.int)) decoder.get_frames_played_at(torch.tensor([0, 1], dtype=torch.float)) + @needs_cuda + @pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE)) + @pytest.mark.parametrize("contiguous_indices", (True, False)) + def test_beta_cuda_interface_get_frame_at(self, asset, contiguous_indices): + ref_decoder = VideoDecoder(asset.path, device="cuda") + beta_decoder = VideoDecoder(asset.path, device="cuda:0:beta") + + assert ref_decoder.metadata == beta_decoder.metadata + + if contiguous_indices: + indices = range(len(ref_decoder)) + else: + indices = range(0, len(ref_decoder), 10) + + for frame_index in indices: + ref_frame = ref_decoder.get_frame_at(frame_index) + beta_frame = beta_decoder.get_frame_at(frame_index) + torch.testing.assert_close(beta_frame.data, ref_frame.data, rtol=0, atol=0) + + assert beta_frame.pts_seconds == ref_frame.pts_seconds + assert beta_frame.duration_seconds == ref_frame.duration_seconds + + @needs_cuda + @pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE)) + @pytest.mark.parametrize("contiguous_indices", (True, False)) + def test_beta_cuda_interface_get_frames_at(self, asset, contiguous_indices): + ref_decoder = VideoDecoder(asset.path, device="cuda") + beta_decoder = VideoDecoder(asset.path, device="cuda:0:beta") + + assert ref_decoder.metadata == beta_decoder.metadata + + if contiguous_indices: + indices = range(len(ref_decoder)) + else: + indices = range(0, len(ref_decoder), 10) + indices = list(indices) + + ref_frames = ref_decoder.get_frames_at(indices) + beta_frames = beta_decoder.get_frames_at(indices) + torch.testing.assert_close(beta_frames.data, ref_frames.data, rtol=0, atol=0) + torch.testing.assert_close(beta_frames.pts_seconds, ref_frames.pts_seconds) + torch.testing.assert_close( + beta_frames.duration_seconds, ref_frames.duration_seconds + ) + + @needs_cuda + def test_beta_cuda_interface_error(self): + with pytest.raises(RuntimeError, match="Can only do H264 for now"): + VideoDecoder(AV1_VIDEO.path, device="cuda:0:beta") + with pytest.raises(RuntimeError, match="Can only do H264 for now"): + VideoDecoder(H265_VIDEO.path, device="cuda:0:beta") + with pytest.raises( + ValueError, match="Seek mode must be exact for BETA CUDA interface." + ): + VideoDecoder(NASA_VIDEO.path, device="cuda:0:beta", seek_mode="approximate") + with pytest.raises(RuntimeError, match="Unsupported device"): + VideoDecoder(NASA_VIDEO.path, device="cuda:0:bad_variant") + class TestAudioDecoder: @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32)) diff --git a/test/utils.py b/test/utils.py index 5b6d9ea76..fc098c5c8 100644 --- a/test/utils.py +++ b/test/utils.py @@ -678,3 +678,13 @@ def sample_format(self) -> str: }, frames={0: {}}, # Not needed for now ) + +# ffmpeg -f lavfi -i testsrc2=duration=2:size=1280x720:rate=30 -c:v libx264 -profile:v baseline -level 3.1 -pix_fmt yuv420p -b:v 2500k -r 30 -movflags +faststart output_720p_2s.mp4 +TEST_SRC_2_720P = TestVideo( + filename="testsrc2.mp4", + default_stream_index=0, + stream_infos={ + 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3), + }, + frames={0: {}}, # Not needed for now +) From b45deccc356ae7b6c0d22fd85fd9752f4329d42c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 26 Sep 2025 08:34:41 +0100 Subject: [PATCH 02/27] Fixes --- src/torchcodec/decoders/_video_decoder.py | 12 ++++++------ test/resources/testsrc2.mp4 | Bin 0 -> 680196 bytes 2 files changed, 6 insertions(+), 6 deletions(-) create mode 100644 test/resources/testsrc2.mp4 diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 150dd056d..203ba73b7 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -146,12 +146,12 @@ def __init__( # TODONVDEC P2 Consider alternative ways of exposing custom device # variants, and if we want this new decoder backend to be a "device # variant" at all. - device_split = device.split(":") - if len(device_split) == 3: - device_variant = device_split[2] - device = ":".join(device_split[0:2]) - else: - device_variant = "default" + device_variant = "default" + if device is not None: + device_split = device.split(":") + if len(device_split) == 3: + device_variant = device_split[2] + device = ":".join(device_split[0:2]) # TODONVDEC P0 Support approximate mode. Not ideal to validate that here # either, but validating this at a lower level forces to add yet another diff --git a/test/resources/testsrc2.mp4 b/test/resources/testsrc2.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..4694b453356c6799ba6a3cd2a2960069a4f9682f GIT binary patch literal 680196 zcmbTe2S8I#^C-G0kkA9tYeEyLp$SL}MMMR`jtvM+dT*kT1dyVLU;%6t3t*!tNK*(! z!3rvf3P`{L2tt$=AmqG*{`~3v{@;7|-gn5Jovmkgc4l^F_8b5Jc%R5ap?(n|K>#4Z zN(=`*a4F3%6bPzME=ti z+`~Wp{_G1`6XfMbhRg6lUjI3FGdMdI>}>nlQp?vXF#MlBK(wEi_iuP5-UX*a?8(7i zf!-^1kRCyP!9Fk`8Xfer)4zwU?DeZm)+^k5rOwK=QqCV09*F-1BNZbeJpy64J|Z&W z_kyg{xv{cnaJ+K;-RI8?u9a!q!{AC-xe%fbaQ(Hq+WOkML?S-W&m&fEt--IdUw26e zf&n62%LZXhSAd3q)Im6Y2b{SY5IA_1z{(9p|DL2wMF6z#(~LWCeOxFxA{35SMz9h@ zJtD(@p?3|8?Uew4WJKgo)UUuRbzoe=s}9$NXY{keKlv+VvT&Z~N+Epx?sEnAFpNOM zWpO`Ce%;{w@JDcWJaBLTz$66#BWE}k2Edd6fN?w=(%`feT#kohb2u~sux=PGXTt%W z&pLQPNv8oYuYton02X984dX#|832L|0E#eKDoX&Mvr5;$KOjx*pqUO0wff+`%tb5=M2*Oi0Qa63GlR-A*=R5&hz z^I*IZF2ZS;d=x*!TBr{_4MFO-^j?&bvimRF)`XJ z92OEt4%Q9{Kd7_v6m8$gpg_1zNNA*ANN~hDyeHX%>}gKK!(3#pkN5KS2n_KIFxOp2 zTt~!{gUNx1BD~Fsv2Z$8S63GwDd}zcWbKRA@Tez3Gt~L>F1mEyM zezD$OZYw>mv~UY22Osn{*IkSE^bHRQBD=v&b@7qm-hqLB5in#NYwYD22@kj7iC;E^hBHcnG0{lYZs{eEl7U~w_;}hW>X|AP*kMs?PTSS<{Qx6OY z2_XBzDYxI5`uK=IKTmitzcYyV;P9Uh@$?HKN3IOSFF4XWJdg}mg7ZBBqr%CD+&n{q zLdide1mg&nuztaC7jVUJ@=6__aB`4$1PsM^`yKR+427Wp??Z60xt`%~&BiAGR_+bE5WI7z|x`(9f70YrfV5&WMQg4mCHx`-Q@YhZh-c0ryP~gV%Ir zg#Z1bq4(?3%fOo3ifvM8We&vI$JE)bT^uNG_l3+u)C9#jC1M+u7>fw zKK>+1=cLO?NRBHyNFst8APsd9SRVQ(@0=PBozrz(MBRFysM8A&nnZ64PQLcgdPd4 zDuAv4bSXhh$=@CASEaTwuN%Qc8wYK9{PiBdpV|$5Bl5RVTjgdy`a(>29}A zerOS2#+4H&mznrtkj+@FaHd8JvkOtPNykCz^6IG;ge|^Gr0))D+2Z<-yO@_2^cwF# zn!1W*BYiF^f#&b6&?riki^c6oeN06K(yfFs6$>R5-_n;RR$ndV?PzLDyO~yc+Dq8& z%T+1tR>O#f18u)hYKW7vNry4G~-DQ^CfTNI()b#Lt#B(=lemP5{%XP#?~ zG5aY;3Wa3bgG2eRK}V{EJd@UpD_rSeY?uEMAy`s68%IKX(Y+vxy1fkDL4S-Z3h#b} zr7AMBUhC~Kc+}nE5^zIzr-eJHy`^-1ZApv6rsi>M>CxMMRff+zrah?B;EKtPB>yA9 zen$@lUA50tEc<3X&H&5@#_W=9tl2bZ{=W>H1*6R&CDr1#Cp|;)Hl80m3WS43eF0m&wL0c3rD>dIU*y{Gf%VFo#G=`4M zJ@G@tLgD~4o|(aO$@=@(`K|;X`vgrWP@rtB>bAzT#b{ZhBoij5NDP(Q$LVrnj)jA|6E~0-qca zhkS<-F2@l0sX4FxXNEYC{C{R_xwY>{Pn>~L-|2GF)$2)L299C%E(+OkJfM!DNLFt6K^V^GVt zeJ9DF>fFH}5MyEXCXYmc|9rDAGlH|h=BjoV_%Lf3BU!(5SqQ6G{&=}Gz3=~tlSg= zJc3&#LhbA`L7vpQRCH^9o0S@ziQ3@xY30gUd9D*jf$73~nTpWHyTY^K27Kv!!UPXN zfJgtlx!Z~z1G;qnP{i-Sti*H_oiN9TEO>zjGVoM8sA$9|JZWACAlu}Sm|7PgPQaM$Z3X?j^&+jL0+1^V9DKoW%SWsk9(MU6N?Cby4iqc=a*i z-3=Y_F*=Row+r5U>OJ4C%BN9QKVHncxm&FFK-AA)h&+^y#mC$guyZ@V%R*c*z|} zScQunR6QGA#eqWoZ@$k^fw5thb|EEvm2Cu0lC}FMl+vzh9x)oh1%Z36Q?;v#FBhd+ zZ2(%sE>nDW?dakw>cH*iEA<3)vphtj-==RNu{fn@n+V9Q7b|L=%%}V4z%vaq-KC}w zrp5s`Np;2=!w(H8h?PpUsR&18_m*xcyyS*z+vr(nmw+Kiy9^_wpka$<-;1(XE=6!% zktn$hT#4YlfX&Val3!J4l8l4^zc+yrF1uO4D(KH5{DFQiM{fYc6dvjd0M@Ut*9`q9 zOTzJdS6kQM`q@|{fdh?ZC}+eIBeQL)gvUhWBdTn#A{UapUlEul(}! zf5q_K_p$f2YuB#ESqFv65+TwpNvkeJV@Xe5U;6R-LngA7EFQn=52E{9AGSlzz09F9 zgaGuW>OOMfv`rK<{Ty0s7b5+epp$g`^QPw1WiMH}I%m8Vg3@Bq;%8)W)@JHHV)`sO zDep{@j}5I&|EVlh-`KsfEgs1Rd_V4Kt*WJa@gYPGI{E z`H68(;Q0%Vm$)zcQ(xaL=U9Sn;K@Rc5uex^S71HQ>PQ!oFKX`x7>I$qoake#u~hSw9AREQJ=TTZHvfsU5d@JD_)o< zCr|GmY}Swc#%uj}$v=6G#zveZ(i2ecE0vit0wH?NW$&zKZ5;vmJ?|q_d!zPLU*X@! zZ%j$DY%q&~AP(P&W3c$jBPkQCw~-mcAcLcP&dEV!g$1{*+?<^LVbEV}8r5rmKu1;r zB2|czP*H@pnGfPQEHu@L%B|z~{(Xuh@MzfL^*)uKs4&Ba4h^-!9+eMGGcap$Dn8d0 z^3Gc`1V=hMuQUoX6p~52NX)nxx%yoaC507RDp_> z>QRyfH)qD4BUr%wCuXqo%QP4GT(CHOVn;Xb-^*+vp= zirl_UV^*b|hV-!yru1_s+jr8WbHjJ+V!@*=eU;4}-}%MwRUTrBLzoaxZ|>l>6I+Yp zv3D?gsM7&#cydeE{9LyGvX1o-7PH7jqT`Vo(!zSi#yFFFA^uK?jYynS1GpBhv>||S zwmHEVM(Q`y(csZ`uHOrEz});_@LxcY*$P$clC9jL{PApG#OmN+nR7}JiF6EPRd^L? z3Y~|nXl7|${Tyx|oN_pF_5D@1#>Sc8Nm+p1>Cb)2x5?ZUtqkJ?X)a20mn}>y=;EBX zwpX_c^!SvY_FUqwP{}BY`9^`dk;XO^!j{u}-RSvnZ_scBK4FliQvb0s4knF*u8!Z5 zL!Q(bPEjRTn6x-))n&FU$<{rTKQs|JCTpy>S+L^T4*&{p=Ew?yJ0DZsGD+R?+_uCKuQesfK9fp1~_d=Sl~6eDDFW*>c{s+K{TXMa1RF*ZI9&-KL5_3m6g`E%jVoB z#T^yhI{$MQTj9T7Kc@q2@81vWHp9Y46jdaQJS1FH2a99NX_C8VG9Fz>y||^3hijUF z*z!�IM>6GbpYCIh5MUq{k8$U6qwX%M%f{5vn6YeBnNsN)k&}EEG>zwq^DAk10ih z^t}Z;**tE;LsV>lN{*&Y6T|s7g|5!{xxq}OLClwKY3Sn&rm^H|V8305J}mVa!fK-w zI!s{{a`N2rmyu2~j`u+7=JKRW?6-Z|sEH+rI+SdhCSsIeFVk2+-|SfSBi!5*O7l#O zeUo6G@3pg}LwP9iynD*O*e$D%`z_3B_oNDpyWKwWewT_T)MjQTF4)J2GV8sAv~o`z zdXl8|LoFk9`BYbrTj9behbjO5vu2WtH5Gq0KB&M)fq~v0bAFJ8NV=`EvKz=c)2;oKX@Y$PanUJ5<#N$qrJkXZ)s%AXOoFt= zMf%P1vp@8Dry!P~q*btTc2|A_nW4DX?U1U5JF!UMW&(vZB@x>hhT>m55)x6tGR}I$ zvvWsK6ZCR-%S1gi#Aa!WA^047*+)>x%Bdpa`1C3K5VCS*tF7|0WRa<@jS$94*i7YG zcJ@=Wad^80@^qnUjlCygn?0w+jjRDI*L6^$;Mq{dIP665SK$qAZlmBdF<<-v6^^_FbA>4 zav7JdH&vtyqk*lbQPaKPN0f!J(;nfO7pH9mix&Skh=j44Jr)#Nk49wN@q;*#DyppQ znhg;2S&)~4(6b})Rv+R~ZUKtbt_a)!);<(G3~-QeYKE98D6S7eC?)IX>^xVx+;+aV z&}s667edcpVX9Vrpt}a=NGFFaa*Y6k$e^a`Rt_4wvIX;Jq~Xy5*ZBd_58g$nScmx0 z)k7ffPAlO3c7bQ`y0ui{hM(*zI*A8K8M#<)0*B<{Acs`KW@2%wp~EGH?T<^D3 zEUrTqyKA$ixQh;l4<|22S87@)pRoGLmE|;AtTgWA&;4-V#aKlVGPA3DN) zBFYVEicM?`zc_>;rKS6lM6_eMLP)g(_;=nX=qp zsI{+q-5@5TLh2z9!k&{P0n!?DbC;^IF%pr-P)KE|uS?3}ZpQ#D*tIS-;QyP7f06W> zqg>o3cB$muDq`X2%a02=gi;E;@O0FXp4|aW$;FAP8yG*elmdQ$ziO~CZ(FzmXtV6n zWH9xJ$vn1Le1bTM@9JGD!mtA%n^$R8036bd!$)OKBDe1}T;7QIxiMlC$*p2A6Q+Kt zdrY-?pb|~EMWmyI1C;XWOQgc3Zx`^Y7Nk;#NHr`o{xn-7ERKK+VRvNPPam|prIK-Z z>OFL&a6{bH=M7ex?lRl6R@L}l)`dn5VU=(%M}eMi>l;jACYmebajWxyq`JXXtAdDy z^O2{4JUq}kOb&FRyjNw87>1bPqS;DqqxeHW^5zh155>o{g%EezD>+kuh`IO zjSf;kPhv9F%O>e3llSE|xE9{~dsBf$3j}2l=sOVx84ve1 zA*Hg5>ew1eSZ#d<;_UvHm^q{B=!>uYfl#r>{%!nm3n(*3-0Q5mXP=cZ%#fiwu{h1g z0-YS5_ZOCulHPiNhld|+tn?u`4b1iG<`z7)_2RiWl(n@z^@Wd}Kwptf!BDZ>{|PbV z*MC*b>2+=wnvv$BtMoXAP>@_zunvhOogpPkqyypp3v-I0*xjm43@$biG&|3o6VzHk zbTN!*3jznxyeBkPk4QytFvA19Y%UJHO@&QG6MjG}4+}m)SgHM`Cb7}QH+%^K$yjW< zFPVyJ4?I=Xj0}<5uCKF^iga#;F`ds})EG`kp9=RO7iEw!j4t9?QNDbW4kXW&y*{@#31B)NZLrD1H-5CeZ)dHsCiAGU#tt-UioG@j$#srU zx5QjMb8HZmK7>V?eb93}uFIP+b+0btK4-5I-eH74G1JTs6!nC7CI5b4T$F;GI!8OSR)11u;`ApvB9h;d#Z71q zg9(C}jFfJ{OR^MDR2=zKDls*rFv` zHdyS8q-b3xU0TGg`5@_zZk?4X@E!%rFC5+mzT1TR(2h%8l=jWao@KWnC30{|hp9Y2 zbhn%IMbnnJXL*s+kx5MC#;hTilMST#M8CH@-(CqVq!5`LU>SJY#2)ITOKePuAk;xk zQ%3=*8;8s=7Aq^y;H%|zj4A76Ds{hVl~UaFiy*3ouBGBB*!Z-$kf9>ImBhfIm)kf&Zcz(Cu~{O%iE!-uDms&0&n+ zzP++0e>pP#EYcr4>^Bs7uU1f!Qn*XN1EG|R{lG}nF#xB0LvL*uC_LP3FoB*Mqfn)u zYHk^8u)*^7GinEPAbxMnwbBf1aT0-K?oWS;{E9tl6TvJ*(7AbGD^L^Qe(J!oTy6uV zCz*tV3*)%T2aAy-2#9S%$5$0Hl7xpue1w(AD>`5%NuNMxTs-8NEc!xHQ&)l1$N+9P zy-wxMGhw@rsuhIxD5^`LZb`&(3Bsa(a>NDTYr{8pW#LFl1lZCJxR?iL$iJF1K%;Tb zNGJqkb00I0!SkPBW=EituW(QA`tl-aM-S@KB}&Sto{NE_+f++2U7U~=(;A~dr%BnT zj4U%z@l~g$C`-wgTekXG#(d*|4VDYy;E#dLR{$G~sqxHkloBd~6s=h$oCRB7S_moB zqm6z^>sFM`KbqeEtJ3#tyZocA)?IW4IQcYyLA3#-clFUOZGe95SA5>jkOQ0Ek|(~o zK_Cle;)mZ%_c^nH?tv1$Ae$6Pp+y1g6f+DVjMC@zeWy{{L_E*DHccUxJargdTf!(i z=*Ff9*dz&}zv~y(vCuJ)VHX9Z%^QSUX2kO7hYsZqw7avBTKY#PG;Exv7 zzF#YqwazB@DPAheIPDIjHp6o07Q(_$#m%zj1Dz&wN9N)j#+Daji$O@-zUyOseGvkR zOP7e;<_EHNqaazEq0?-6?hWxvjP0n(-zKlx_0!rSgP!ttdLidW2toRp?qkav!x*G< zr(_D|5rw)8yKDU}u^4c6#-*;L1@3ZKzS)myCU}P~sve?trb*Eww{-PiYJL45H_$I^ zuH64vzJH*V^?>z4*h6VmcmQNVUz)nyf9}V$b4UHCXBz|GOoUmwF;GGcjVQipK_^R= zph61GjTtG7(G;2Y0e%UZU0maiGINrQjw7X~MH- zthukUgkzjmE9Ic5B94M$X-if52*0;7eb-=Ax86q^EeEh-S9P2a=xnRP^N8E5c0*ep zqRD4p3+6YRTQPfIa~uAYYyU@^(279p{Uh?4(rh@=+(a`y{Kr-$dLhAD6>tN7{!K$0 zH~(_J*UGLsscEb#Khhj3nLv8U;&I&3dmWE=cr5TtC{R%#4ITGXLwdS$Al@BqX}jeD zNq$ntpqQ);3%{nHHkuIk25}mt#ts;_zn*5NG-a&E)1W2P()kt_f`M$(;nYlXsl0at{ zSBpYK2;F0E#y#z0(#b3A_7|7^CdR+T-@kc(+#Y55zAs7OKGC3p?xJV3mADD7rCQ5vhj%}BC!slhqg9}*pI8@ zw)G9~B_BqvDvQli`rBRoM`XqQ;$isy2l|876%6O(+_!Otx-`+$^*6K&=q|hRy_rK) zwBM7grvxUy%`)-@Uc^AnEUf>fe+M}IR2EXFxB+ptlIjXi)P?hP@S zYzZtw)xHCvdKZf)V~)#aO30|6v56V^LhYZ#-#^08K2Ms+i0f^a|1?jBk)+uS>xr*)#LE|UhGN+AJ&DXb!~cKisUN?_aa-b|Ug^9$`D2=~ zz;d@-cr zLR`F*%VUt4hv)dyhHg&L3|>A2hPJ{3_YnnsVYNvgLnTR!sN>DujDUh7+~o!8TL&MZ9Aj z>;XD*bc1{AKZh|S+ZbXp5I*0(Cg+eI?AS%qo!@y<5qPlXr-Gb|!J%z`DWbrcir%G< zkPnb!0SjBb+MOCA-;YL0CocZWaVsM1-^3s2ZxQydGB=#3YO`9ZQ)a?%l^$TxFU_YUC zA4A(5cdM_AS2$t2T~^<9v5Q1l(?M`u!{GHSlL2-jhVRacRj?)u?9u!=#$XjdO7WF~ z2SwDgt}0K~x>>?jT|y6Amip|c^Or5?Oa^Em(VXv3jk$?=oJB{TCzKE(&OIT4;tqR+ z$c4}HMv7nsJ$cD1rzmEyi`-^?772maZ-DD{Pb+JmQK5RS?fP^#nifRDJ+DgAfBqi& z$nwXXWH_x_CsVN04_>}LyyV8r_QnGINUdscABDPUL0lWTPFIHzA$N=fyS(g^6N|*2 zm<6v`VqE>QBt@brK2c!gbD4!3P0l10ZBAX8qO_3&>`uayGw=wM!Hy{x z!qGPtDGc1M-8B~)NL5V9UAr-T4b}gg6_HN76X>K1OmEs9z~R?Q283h2Hj{x0)b+)o zwOEG?c!ADdB+`)$`oz@}*X_-VCvT$+_;v6dNq)r;(g577Z9}n;X z=N+}0wiH$X&@q?k`$pLe>Jtgh{TwbTi5F*wrn&%|>HVL7xvi9mbi_x3>6*;U)wgA( za)~!xN0`~tD4}*7%UA3IGZLF-zZ;D}mM)xCKYY;o@~8GCw;6w=T+G+c!CfXrNjKgX zy&ou_MbG{=>;6|~z)w{NHvfK?{$&WlsuFA;uAIod*XB!vMhyhqdA{x3)${;JNB8U< zcAlIqSUsY;{dS>swo4!8El&#`{(tcGzoqSRO(RSZdlK^Rhv!*G;#de zn$N0{YsBvGMK_J+bL-vAy$K6g5A*?DkNF0-uRQ;S&;Iv|zT#)ffhpzMUs5ve-=6+J ze|q{$CHcQpi&wh+W3c_z&%e+6n>7Cn=+7?p=Q#RLKYtqktplwP|F7Raa~Sf~10q9X z7GYipL7;b`aAH{|@Zc83nz@H;|E^rWTijZ7H@>Q}cqKOCK%0HInumgc{_(RV_Stu3vh|i4jT!YO-{I z=SF&3hb(y-`}ywdeZJG0<9QHTL~dHM4`ETJI6iuLPmj8LNlTYA(9V*-ch(nRxlDof z=}?O$JAGHb`SB@@9JIZn& zvWG^w48Kd8p79lrQ;vVZ=3F1-JYdHhTza~(v%6s2f6b)Q3uutE79bG9R1go8%iA6w zev$Tw`|;5mi&|6kqwDLQ2=6i+u^!OZHJ=x8biCEJFHb~E{cY~B%K9x*gI}8li%;>y z_NPlAbj}=hPgFQl9O8>N^MC1oU-UpU?}1?rE#x?Lg4l(j!SB5QQXn7KHM)>NN5kc{ zC%-Yx1@2LCy}J+s1^0Mb(&Z)lsV8n6f9y8Zi##9sQpT&GK(9_Q#!_FXo9o z^+|L6%}{&Xv8q#_smG(rrgS!JX$hJ=^1OW8bLRRX8-KpA1Ita)3JYFO+Bu6ZF4W3J zR71;|Azqq5xvLkR#0zgD04rFCIJ_dn~LpO5-5_TA5AFVh1c5o*SGfRiC}cF4 z-{kR_tr6)TroT{;_5|H^myM#@(aOS)-ZZ~x3Wk<9ah1&FxP4H?vR)f1+bSjSAjmpr z({&}T9=d9F#Y^YH#RIm@pQ!0yPkiSpU6yccC2Nfs6yN07L(rGa+^NC&!AY=;SHWNb zn%fyM8=q~&QhZ(3u@&Nyny+gtZyJC|;LtwxkXhkN5bau5!;va$HHdojb9RC8nEiZy z00gBc#(Gf;!NKSXx6}7nA_F1@+!cF6PX8D?Cm_4=$bFh%+2(h3*JX0_KCE)iuM?W| zF$^%_XP{CLd_?SX#n)lVSGVo{s1f+o18A8&yw)ZhLml_ilnz`qh%|xv$+Np9HmqN& z%2WC1<`L3LzWjL+UKKRL`3op&P}GR+R42sTwh3^^n5v4aDaHo;7C`?giQx4Al&jBT z+?Kg%MA1v8pUUUAGp`sZm%o`>GQ70rUBHClsSvT3ZqR2AUORb}RhxIt77P6o*Ew?f zIRZ!8)hTJ8^h91i_Ac_rF zF^^7=G{@c6Z+e2grC?$~E+14h|7!6NI|Ml~^{IeHN`&dUqjW<59b@OADgIrHZxd9e z<=$>#4fqVQ_NuKur1=bokZR@|v>M>D2)jm<=1j8hb`C^AV%Y`GCyL7y9<-TQ4Zgmx zJD-x@!I%&-xLGs7*Sj}zPU_3WSdDKDmGH}ag^5=%dE=`NzP3(R?+`Bg$jQ=s;JQpH z?EQTFt^@~p=;>9Y;XJ=d?un8KgAs#t2#OJ5QBscP^$ydP#O(*2C`a}z^i<#37t%sv z0W1QIL|p}Ft~6YyCQa%yrx0|=I;=II3eY+m?9=Xp-zoG`iW=B@eN=Klmrp zEr|V+SY=99$u>3N+izo)Yp529(Nov8OSY22nphWVmq*17UY<4Ph7Vu_lkg;fS54IC zQMV9`#~!GiY9EcQhs+e-9mB7890cK|>W##WgY=mJv;O(MSHA4*b>DQ^maW#Cz8R^$ z#Kg*^$Ps+;l8$ncAOmVuuW^92M`FdZ@Z)aB?7VlMQ1U^_rdxQ+pQ? z0q-S|$apXH$dB4-yQVY_E_nz{Phx|4gUE^d5Hj0U3T=Yf0qTG$ZESCJVJK z*#;3pXN&uUqK4N-G%trYVPL5uo}8ivWRoPnAg=7tqaWt@N(GBCg|-#Wb(lmJV5Wt) zs_0d&Ig}*zG3*+lPSRfUM(rBv-M2{##TM5&PPKzfgbd~Tly()kzE3Ll0>j{4C^e^% zH&BZdGNpy}qNM>Wx{v{5R6G5-F#kx9G(mvQQ zDaWY5uAI#ici+i{>LOa%22t(1Jqk*@2Ad2Y`x%_H$h5>I#BJEFW9svf7i|7JR$|%4s8OWfwfBoIT>9adl6w$0_se_!Zwc@ycW{F)fD_R8 zhK`RoCGt>1`CjxkobWNNh8FIVZ_qkV@{uC%xnIQ`+qZCg_i}sFkF_U-m@hjX(VI0m z>_L?Eou{!xDxg|(uk0<7n`d6^XgirPe{kYux!M)~K)OO}K?I$Xcuet$x_!eJ8&Sfl z@6a86`-CeYhI`*{x*suP^nIsi8&a$LM0l6TX-QpwcBklt9<0-aOf?T@0eE-@vaqb; zA*~vyFA9(;-xf=htk38@z9_maWAU!!JTd>z9OJDZmf5{6fcs? zHgHyn+h%c(VwD*9XeV^Qi0upH0aaX5)Y>ygQiAoQuljQl@s^CBS;^>nsZ_3`*9r^a z^8`G|@El5Ac$zyqA(Wd0?+EX8^Vcih*Y!QVY15In*~P~@XITjtA&7jX>W#G!zbx>j z?;zCkPy3`u`9Bu0yrNq2Lh?Iv*fiJt>)1}xTiD!ee=lP@IaX?Q>8e-s$<3yd%S4-h zPM!X-Ip7n)|6-K{llqZ+-hBzC- zL|)B>skAdfjs#WbEFE9ocFXngnp#+#Q*T)SBnU-8Vp&@>nKxZ_O1mxCh(XMgruh_O zVbKjQ+R_!>=4_VQ*38{gw1#xD3|4QDIYDu^S>n77n*1i-n6yAi?O{U4j*Gi|&N2vF zY=BR!_321kfOOX*NeN0E%&B}j)N*g${NOE`)ET3SF4 zj(uuk?J!|U%SJ5gFUO5#_ps__PChCL3CKHXTY(9}VL_=i;67^FVtVmmZGR|_M;prIn- z{v84dnR=146iG{M1@uD8y++GcfYby$u>7E6NnqRCp-XXZN7+&zb2;Yp=tEZPAAS3X z%HJ>R066X)MOHr1-230W6Noi@cKk}aR%1yWY+; z*TiuxhQO-&!^?iHZ#!9XU0X?9SeiS)i5a; zj1ab_#|wf+({S4OannhTVh{X!n{9Z3dW$_IYNS$Bcm2*A%_cg-47g{ZR)q69)>HK$ z2J2~j2xqUH+Xij#p2N+Mbpdoc^?S4!FZ~8ESx1xb8N%bnOXgEqwTVyEH}Aa^m&4f_ zqW~y}~ zsi*8%CCN&ts3Pu;kt$_x$JtLQr=~{rjk8daB1nG?>#9=COHGeMk|}+~3ZE~}QXrN) z{Mn~0kN^mjBKlwsDdP^b3%r+I*nR;7Rjs+JEi??=v4QKeRTYN?cwMv4)5f%(=ZWae8OiziA^fNZvF$e`+>5<&^WlUB*S9a?&X47t?etFwUoCnUw|xo z$}QqPH|#q+WK0C^w9rtXX3y%3=9PaI{T@mOM~0j#e_U=K5xp=_$SJhr-NOGkN|v;G zAi}uz)R`S+2LOI%-s`}@O*@a+Kbq@0xF_zbw)H-by)wkW^M|_>%p@8W%F2(@pTV+% zNf(3_3E4Bp=Tx_?T6C|eT&Ua4-z9m6NmxD7Tyduo7Xp3Z3*EM)*ziD&p&04Hgro;KFb7J95j6*q~aj@o}c&_jn>T*`x^4dZfNU$)1yaS)+8ovaa78z z=-}CV<*>JT)W;}?xf?#I_6rLKPD#d^sbRH>R&ArI99s}A8PZsoZJ1oPtkY?!P|A(3 zki3hw{d`VRpC0^R=J{PHSM(raiTlG02_Fijq*T>ZF|)dXRl9o4tw)3V$ffPt82Ak! z_d=o5PL3AsEHd5fNWjeYxg-@zx|_mmO>P_Pnw1uL#a*XFqa)~u6y6n{b3ZlEulKa+ z^;pWk_ReqMFaOQ|aLeRo0@$I*g|I*DD&kycw|3g;DW`Die#`F0rY8gWfdQO*5RcFq zL6cJK)dr86w@(ijq~bcYV+-!KyhaHL^=mK?0yBL(3|YgHFM0$yFVE?v@Ai9PzhvB( z_pI1pa7ph`)Dm*LcAUlMIA!vnW(?a%EO=>(eOhp%SLP%m?fpU7slFS-2QHU4A)kMb zyBB^s=U`KZ)v4KYZ^qwjId2KgqNpTfVJLO-XZ>TO&Qw-V=16a%`Qw$aZ{BMUCEl%H%%ZSN27`?$pO4^Afk+Y6px??~dk{ ztkPG4KOFQ)sEJyqUAv*=>^wjvqh~4+r zt-WdFG(~3!02G>vpaY!!I)M`L9nahO$KOCDjCGAkCyvFR43*6%4U=SH%$Pw-WSPUA(t?0TRw`Q8C5srOJeUiF1DOco6nk<)i z3@rF}Q2|d-#zU107D|h#1#-fV zeANlci`>Q6b`Lym>Jt_vX<(ED@BoSQwet|Z670Cj;H%`U_^lcQpZPIg@|A&%SE1HQ zUw!;KwBr3wJTzD%vG0mTIV2r+cC^HeSto7e9A(r*4;eW&Aut9dh_nJp{C+hM{`RB# zY{k1qUD))IZf1$H?B2KuX^n=BCM{cCB&~l`|3#O(z%~TI!L|*noDrj~W*9ErY`(XB zgB|;Eoisxs?4@D#9XjW#-d3^W3!VYXZGEc_znln19D_CBtTYFBv=xa*cM-e3qicdC zPWNiZmn1onu4WicRS-05N^#0nT>cBHmnl8fqHl70zl$$WQjS~bn04Q ztZy-&tt|3s zKaR_aj)%1X8#zZli|p}k@GzO(q0GuE#nGQ&-Wbd4fp-*CNLtn^G zHDzl>?31ot?Uj zA8sXnuI4kL%!OZMSEK!mm!q`##YNRX5e+7}v$nw#JU6on%!z9zEgqB<96_qTq+p%q zSel&YfzONvE4B^aDt3d)5BUUqfu4~l9BDIG_0)$6huFj6AFjT9a3P(7N+KZ=ZA5_} z-o~ub=3Qo=+az7eJ&YK=!)0Rlv}u(>H@e?sE$_KvSMFTlLk%-{*E$~4x6=2{#|n%s zQj?61eOUDE`H(I1Y5MHBrH5+{-S63@Fxp2Y?`hwFf;Sm4F8IG1P-AZyZOnQ;UEFo% zg}Z*jyV`&~-DnDfuiGJ|G7D8gE;zzIFyQgT?MGbpW+(WgEZZiQ?^VC;>!Q}b`nc4~ zYhLjZL4d!fc2BL$=})6b?S22FD{p|wsobFbWUEw)>9#J7=pNf_UulfK?vavuUnx{LU|0n$NWUYw7$pety4R^5;bpxMbyhJouOJFf5#e z+s-HikKLM2Q#iD*xfiS1D2h|0*!-7urBwqpix{mpmn@q;aqWCk8{sxHQ2+In+=!n7 zhnXm2H?e;3mGA4UE1mS!Iz3PAmO1wiJe&UvHkVviV{A7O6|=Ig6D@}PVv@lHI2 zZp;mkNy(@}-V6AX(T|+xap$HmZ*EK`SU7#-okJYBY~Hu)SPUw*uOk=gZQM%X7`V1+ zH&(aqSXwMP;Z>yAKcmp^CGEDUYIeSNZB`tlKw(2)WC;MarO;Mb1(=alD*W1A@9{_U zx<5QUI0XCtw%+oNgqUs48whVNeugV}F0yxZThs@dpHNspNAHPj`f6A2W)vqHUD-UN zJDzlMP;{i)XK>`2j^1v!LB?oPdO^kTi|UB-DU9P}&EQhT?Roo#fkivDrqz_!+q|xq zSk8F5Zl2wdCq4JFDJc!?fcFCHyW7O_bQLc1O@7yTs!-{2vwCquH8-U@r)mZ7mdVf)>TecampzzE>t8M_mTe4l{GCLbga>{p+!}r zK%n`A%$H;3MN1F0;rD^jgc=KHEkFI$7Mqfsu2`B|-KgbhYmSQAih2EMLyqfy(@WQ4 z@aU>joZPA93G`D*&5f4&BO)}xCc=rqH}+^}Jd;<@G`eK!V#DN-zFJ*l=3QV`@o{dOz5c@e-LB1xMc;R#=Y=6p zh$9Jq9;~be^x-{)W_;gUX1j}|e)&F*ZF9SeC@XhUp=qgF?Y%C_2DQ#mS}yK{x$ktP ze4&x_*RU#SkBx;Knd;l+H2NzeljmIGwWs&2t4;kZ^NLSkStq36WrKSxKqJVhFk5Eu zE}%0T4yKij_0HYZ%F?TSc(4*h;(N1x5s8EDHNQXj;m*SH^>>ReY#pI218o5^VeW0p zAj>Xc>z;;BH+GLJe0PFvWJNV$e%R_G74oNK-a5D^EUj}iex&h5c~X?4K8xOVUOK>W@D^Y8f_fraGh#N2mCKyC4!r@g^LLq+SZ5EBMP89hxCM+?y!}LUzLPKd+1#{rqj?}oP}QK+_yX73RSXW^Uj4~>%QoR)qpv%ZQ3ZmZR@2SX zYzytZ3*UJPuGZa!ch~4GOpJ_saf0`v#88VjM&{$S>b8;NAJWjMBKW&GSaZ2@Bz-?4 zR49jQwSkce2n{8_f$Jt2WBG>KWsP zkJ-?zD`ORC=C@|?W|c*z9eaX52BZRQ1iq>tKngU*#ssWGLN`=3;)_-Fp18rD$U4H& zsK^FR!X2SK4?5$&hdmXh&|q5u(19ZWB>EK_&3ut9Qd?%57Te5mYpNTz+3%QD+rDp{ zybO}pxTECl)0!?xZSlP&aifW=f2;EGrc3X;PM$YDaq94X9p0wEdhG=d1jdTQ(_@xA zPd0S3&ia37d+(s8qV-*PLnxv5j#5O5pny~bN$8*;O)P*kML?t~MOs2{N>QXr6+tP2 zNJmKMC{>gWh7Qs@0g~*yJ?D4MxpU`!XXg86KK>vK1KDd<-u0H}dER#kohV(&A1dHg zE|W2I+ox|i58DWTNGZ1PtTkY7hC>jpqNOB2;)PdfSUHqaK^BMPmT>JfUxPvr3Opb; zW?F2ym+u1e0GO!-$R4MNzP~3V%#`wT^Kh*KA1uhsM566~GUMR~guVP4=+Inq?S&75 zDAV_^X!I{+?jJe;kf1C@@z&UmjqC-B3>7b#Km7fMX*bLss4jpnZ zNVowRl}N?+o5PFI%5d&8T6Lu3H~ojRDOf?p}S+K=o<6@olNo){?@vdPJEZX5|2vm&la;{iyh2@^a8K z&v*QBL?SaO8}+H&Czn^??@XudYeh$1E<72vcS<7?P z>u0r)OQ{G55*Pt$JXsh(eu1pGrOM2~z)69MbNWk1wiJZ?4H}$z=6!wC)Rx*gRCTYrZU_yk#sC{!)orec=+ad7=LJH_3NcIyfSBeJPB@UmZlK2ML*~ zE%BwbQo{R;RA_#|3o@3E*mPG?+jEvLFPq9`Douu@*xlAlW;3v!5|-ZNaV`?txKX-* z;N1TOSW0gXR$ejYcDo_xWs_#}@mlXLM$~Z)q=pazVK|a(8(>yd?{(|6qplc$@TV#d zLjG&8Y$Tp~k&+7npolLXiM*4}IKR!yktAu}E_O>BHfX^>(X3=q_L4)1>T7h<-Z$wN zn%cWCP(%9iNUDdWl7yFb_iUI9#B^h^l+e+!5?e?KO|2+6)b@>QPR#LsNa-nm;pLbD zVq2_2x+#)XUbNdRHg?~}NUT{Z_uSvJs1#^|64?xS*RT^37<{hzk-6on`?d!Kz(Rqh zhTgFItu;^QuMX)40TZB|%;UTPKC=&LDDlPL8`QHg5rRDk;LW5mYkPH(8$3Q?m0WA% zTWOnF##gYWbatF2l$0NT$#`UP2CNw*Gg0 z`6+CEhu2wfiko(Q%L)zGuqW5oB%yo+wQZ}t*;5ELsZER$+59)Ic*Q^==oUn}A+>p6 zD@BY{Ioh|%T1^Z(Hp%An=zh08=z%?q2r6Jj2x1&hJPsdOvT8QX3%&S?;1zk$*p1tG zjDI5CvG{bH?b!hsOih*x!DWn!4IbBo8AXUe^w283 zaPfvjx+l#is_^>54-6VcLmkt`*$QU4vil(8Z>>9P2bM-%#5x`HF-kcqT3KrfYx?xc zK@Qpj%g9QO1q6kwt}f~5CG2Buhb_UYzemUDXT5QCD^wX);W_i ziRfOv3OWWpE!B$?Z?yjea|^7S*nFEAcDREBNz`OBfNeNgG88Lf3&*}I7K#KqG$P&) z_TQZ~0!Nmx0`9b~VJ(CX%G;$G9DCX%w+kklRp&hi$(??#7++_t5nNA~uS59Hp0f@r zFMA(%c{VRMeDnZC#B)7#oFAd;|G6($l0OHuU}@PYpjC^Ko#10H&Ffid`8<7PK9AY= zo=c;qoqFRfTZE&$L{X=s859x(k0Bty8Wydr>x;6oo;{E?B$!`|3^MOu4cpatnAp<% z8$21Eg?B({!*xmsl2b?iy4>?vMuEO|1U#h*`6Ei}`b3{vanNga8>P;u(La&-$Y7Tq zfBBt8%X}}X{x+Q-ub?+;U2?w4XAIm0TV2y8#Z>t)7<(L(k=R`Su|8+_TjtbZ{*@V1 zG=P*66ooP#KJLuu+runurs;}Nf2cW2bl0y^i}f3l(_BOUZjg|Hi+~k(OD$Ft=Eo@< z4C~LuTMLit5tJYIRi^HNr$<9pvT&+!DJbDXlfexZUpLxUZ$(5@ICogrFs^T(@ivA5 z4rp?K<)w#E*CFqfr>V+;(BIDQ2)mI)^%6#w2;SB2eFY}LL8w`e_ljO|hshUufE^IA zeWxCF_2Bcl+3a;fhT%LA^=7m?_0l~P8_SR|fZdf$aWJ1i>ku4)drDS}90$L*{V_P7b|p> z9)lmw|J8B*N8Rv0@LscsaLt(hKxi&)B_AJ=!Q4NXy#CDSn)?dfnH(npT<@Rpy?CI_ z#%A<~PIC@X|Evl^bGbg6_d~+qgV?9*;|TYe3F|B;aI41kYlC$_&iu3DtmDFcD>do< zpR=e-oBl=IieE|1pA)3Bhz;S6=Lb*#4E{aV*oY-ze=CF8=?H@dNwz*-zy2>q?h%5lSWIq z>J2%9K`V#*6VnHd_9I)MZ3w^tQ7u1hdBfvawG%7AMp=iNaN^FI%}HTCQ%b?d@wm)& z+!`{5nh^5tTgszEI;Q$lceRJ3kEjn1-#suiYMcQf(uM%g&;qvX3&*Ooik`_N%&|hm=nDDanNkMv9#coVuGT*i2zfZE)~NlyF2lj=Ug}mj*vOw zVD#x(bVLY*w95ck4+;-J@U{}wb5~}eI4QO4nWr<$og3z1uexb3QBm(aV=Sfb^6iyc zh}H~@4RyUkNG;0_9sk2@ZS5dE*ZjPisz1N&&{@!~S-T^1u{>&>5u~EVsoMA*`Ef)3 zhnD59?-YgSUznlZ?7s6nQ*GgUKoQ#yzQAI62- zK>(B6+|nA~ruNn0%k3*=++nmj?Y%@vaj^i5I?iG`2w^M;LN9x#U{jK zZhkrS5Xp@iTL99SgzH09r&V{_^ccr>;?dj5)Wcf~C&p)T`;~mij^bw1FP#PaA4_Ku zt84dDCxf`eKMfM(MGOPaJ)L!bRff6fL@t%uGt&*gX#}xW00nK~{EL#2eg^u=BEkgY z3H!0@$EmXw614{vq!|Oii$)N=lu?=9J?LaOb+!wwChw;lrV%RDo-)wrt`U-nd^9^2 zHmwmTkOB+^QGbRR(3zOa7$6}@0htiQVK8&?2us+@(Y)kox!(rrzkU9clqbWdWj8d2 zK+w*<#T~bNi?a;Gk=7oKH%Ch&z@bmA^HF#w{hZ!gL9<~h{wS0n>D(W>11@bf3@jQ8 zAi0I;*`bOHZs67Nz@z74`7!!oqN)A;+U;vWEdvDtecy&R`{hinG&trdh=Y{bGm4+D zKMSUHXD&qA591kFsxv9dI-7~;8ps)K8yjX-WsN!~3gdcZhW9D+p) z<%ZZYrwi6&1ZpWyUvZ}TtCVuqQXaJPVLSvs%H%f{vXZ}fGvk@feIE@F!i zH_VL0+QK5B{w8vmM0ai-tlpJ%pjT(=Swm!9Bs{L1oW`<_K6dKHSq&ONL{oLy(TIoQ z2A77t<|0#POqq~H6=JpiXy96={$~tk8C;@<+O5d`D)}hUfMB5B;P=WV&EhasIa0Fm zQ)FT1h7SX5!hlv(JxYZ$<*ijEp!6@AO>nX{L-a6F>$4iahWhsc?r>5Y+cjeeB(S<8A%u_0NKkd^QRj zF~shPXW(09Z8|cMt1<5m#D$^()b?#J?J_F`G=P5QnY)v4Gxx z#yCbO^DLZ}D@f-J-HE885_@|519j5WJS<*5qoh72n+fqafSgjoBERr?=7XDt@|s`ioWnq;f+$IXBxvf02u>AGf?Fq<5T{sjw*^mg zUOQ=ntr>;E+vNAHCg)HMr5E?k1c9RbS3ccHt^S(zBa$o^)=C%6$+-GSu&9rThJn

ZEQg5~y;D#!WYZ#QlJJiL>YoGYkNlAC=P zGH#0s&ZVDGH~?V~T40L>UP7LB#c7y(p)Q1ew{270Q;|K*CLPUC#co*?wNJ(INGE{v zI1hJbUsh?TX4-&mjOmf)moxXwm}cyW3}i~qrR^&;2Y5$Q!yX#=`0CkQwqF@jg@@Eb z|17Fz*r7a9<{X1yu{%vI#h06Ly-#`;-I%bM2{IOXEoC5j<8GlvNIpr1{mF8dv(A^F z8b&0P@8z?&+5F0ItAXGVQg9fV9m8*GPRJ)>J^p&!2-w(wYd#x)1u5n(Lf*J&9Edfj zI47DR#`lgUG@g* zkhZPiv4ot7?fUFM)&$d%NG!#9w{NlXC}opXowS=ttAeVi>8@YT+fO5lN|`>s z*m`a#K)jy_4UsaZO^E%(ltKmh1M+{cd$7qsZ5uaO9#cxMf78LDt6SrJzl(-fDb5FZ z1YK96P#CRH;F*dg9&1hVZXDeAzW~4>0fKEsUURzc$9G-$H+$$2iB#jfYOtLG5T=2s zeG0m;cO;&bhF~T7&oyC0{0xYRVv{Sm%48%1McYWx3zrlDav)(T=BUx})Bcs#6)A=E zUa#a=J%H-R2LE&XIC%YX8slp9kPpX3J5bZH-($pS{NsS za)M9(1h%yif_$}QLNK5KSDu-R&MnZ1;k||a5pmtZ+S2^AaQ*rfO^C=Ca3$H;=j}wM zH}&y6&sW0hLF)uUk$h+%ROphbfBQ>rcCbPU<=C+@S8B__ykus#_+S*bl(;qNKOhi%4?<>dsjlzqKOj#^SZ&&4%Q1=>3D$CGRVnhZ* zFUhhcWf=S4)Ic4w1p3%ot(#5BBvHxI%SWpOHOuFRGYq44?g}h9sL)}`!S*E1zvqK+1#peXCCCu6V`lAIt=+r0PLft2e5kkbR2uuwMs`Z@ca0XhokEY;nR z;VVJ&RWG8U$|*WZ=rczE?|I2HLz96T(8I-_vg+OX6B0$xo<6kVUj+s58UlEyr8dVq z^t=EKH!bA&s=`{bPjH0HS+pMn`^CMzj0TWr5fVaY7-27krPggjcdpl>68qdAEx#by zpekf$gADEPi!yf67mSf_$4>?RgyYv9^ z_f{!by-(AC#&ZMOrE5%Py2u@q+b#i%pQH5&r^bMK6phFRWB{*`l}B87qc zx-FS!00T&ywxaN+Lnwn}eM zy(m~iJPZ1+44>TeN+?8iU@4)@XX%PppR~k-xvg!#8zpx(x3Sy+FVwjASGGsb5Z%8K z!uyRK_ggPX0P&iip3p&jk)4JuI{2F8h&3*q7xn z!fo4;+Pln;yIF6YFELPX-@flE{nmykol}h>k=kqjN}y*R)El>FZif=5Ug91OE!XY< zNZ_x}1k}kS10CMTmA`kE{fGO+J&y#WKlqZp(P2j*jb6Zn!Uu zI+M^N^=tNIuU`f^$$3`pE36D#%$5IVB9`-|n@%o?v-vGpLw&OlWqq}C>jvw3Kf zvw#ua9E*Rygnu@WvxknXex;ZuBs|PrboapxqWX<-UJyaF&afaHS?=pFVhHm7f=*M$ zeEo+x2WbaNi|%-W5PdE*rI(qU(hnk8Wq8@Za?GK z=Xc=g>iKf0l2G6?FwM%v6q6nV!012JFB15nJw?(`0hADP!+KHr1;&hKnsB)Vb99@r zbYG(KwD}VCHN3<&_fvKmb74_M`=xV%3v=Ixgx@9*RE!@R`pG`J-{bK6$eUQufweUE z7b91E6!3T_r44nUVbl6UznX9?$;=-i4na$;Lbebu<~k0awd4vPnUko$R7BYL6j^BM z$2_??n|S^%++_>7RvmYv7J(G~zVyTF0AK}Y-$j<)(|Kl|Zvzl){%C_)up>oav?2kHP&zh0vv6xEb3AA(1llT+be4;6v!|p`yzsTiN(AZQQIm zjHjP(U;(6DuJzEvuPUc_7S{2BYcp{%tbHoPyTN7cX7@>C@A8B0o~g`n_E6qmw#Zc6 zlKo!Ry9)-gTE%}JI2tn%WB=4N+ctYIbq0Sm&86v!Z?B9I=sQJS^G$L~i1X+Q)5QJe zO&cUO7f_X%VIqrct%ieB*skv83Pb+XBhBKczRT#Slo*hwmX}p(W4=ALFvFhAhQzCu z939VbGxKQcdIYbtghEKvi!b_xZ{G(TIVYuSW~~MgF|F!#5S>XCf`*A|3UTzqC~M|uZMVLh9Tz_D{FbRjXDfX` z3A|FJ_rh(7;il!|i73)iR85gGp@Qm*Y)kdQw&%q@Bo>CHj<7{@j4z35l+oo|pOf^~ zCEgCJ?`E6mUb>svkP_!jA%vU7$?EyGto08~f=8*$bWrb9a=f5ZA?mR%0yLm>l{jEB zoP{%@$;bg_6==)hP#x55-~}234=K4AUIXLe1-*5M=+9SE3~qh04gU2xx6M=de1v|> zM>4erIFqD_4cD7H9=}?G9j$j*(y*#d+g9T@`s*HDGh^g)sS#J}M3F-s3uwp(QBV%T z4eBDiw~fu_28&%vhVSEjZn5w_yoZdOB2W$G>RacIZiNY^^u-c44Ed@SPwqy}Cv5+i zuRqrev37i*&9Tkn5leR(1|*DvoFf;QV@bJtB>R)pO9cT9d#dq0;^m0~bAv)wMm2)Q z`g0ex{L$qNX#yWODsT0@3E7pe>VD4PN`KfCd67=Y1?FCCrn@FdFg_b!v&XazHLwC6 zvPeMCo?3m48P(c+UN#{Y{NUAsEua?*y_=T3X8aUlEqqThmK7MxmY)#1bZZap)Nxat{_zI&AhdeF}Q7_`0Q!f^PaL8V5(drWbqp_1*Uzqy3 zW34+e5$L78#P&Vib~A4P9BIRMuU}F$y&Q;)Jblx|n}gYEO?Ya<)b@N>3Xq2;3$AUh zO+L&V)4XYW%r(*~=Xmn~7EC{KWQ~va4}wBHqY@+jP8)m88yQ#Fz}KmK21h$w-5YAqVj;6y6B;f z8FK2^_B_*{L%r$C4lu_A#=ErVn`@0@VBf<23h`&7^XRTs0Oyw6DrPRdVw+t#@^hRl zOZj@t@{Djz_eV!Xc)3@{kL23qDNhvrZ+9~s;sG-R_*{Si+z@(qo-bN^?VE@u@oTRdf~gq^^G_BO&WQy=AuFId-RwLv+69>B9ME4MB-cpIrwT>so~KmOZKU^PHM`?v z>>ll&Bhz`&Ji4Jxz{p3s#5ypiLP|u7i5rDJN(`$em=c0Z5N+V2`<-wJ@M4}wMO2{FpOnrR8N6gk#%I& z1rGUGtxnjPAO|55q)+SP$Sz6iYqLPtLcaP_^kcV5Q;-I8bX87o;7`*+N;Ish9M_@JlhM}n%?rP)B(=D@K7E&0L*SGJ>ZlJ5{ z#A$kftqe85F9*QF2o0<@73*5VR#Z>Q5#T2`gw!Pb^QHB~^S*R%jSE$qzO*qz(22*p z*IEw1dH2Ig5sPnRXN<-e2thlRHiF+rn=s%RJhB;pGc&V0aWaN^{jkXtKPOqE)$tVj zp>1kb!d>1eEW}V{MrHbc*7wyGh(=Z~_chDuXr z89*93H*-){BdGHn)eem7UUDt4Wnc_siHZ~fe(_0}1Fd)F`IyGuUS$bRp%WLU3aNT6 zIM}_%)YIOSH0g4ucl&f9Z)qSxo868Fx& z(}H@7iG5a;NaTjlj*0){Da1GZkQ8{cf&)MKLJyoR=nX3d|3StpHSCGq&{DYfB~&6cwM4jdl^ zs%vb66dgx<>1D{VSTg7o@=}}#uQ@SlWqr0BUSeokKyr)^A;X)jsLM^=DMXCv)HdXo-LpG}X(BH#mA*sl1|GXuF?G^^_`DsnUR1t~1-_$c{^}Y#0NMnWIC+MS z`LiYkmbAAGzDKq$= z*SQ|nl0seggjtBg<(W}=2NnOsZriD@cs>QhUi@9Hu!E|0P1s=Bfo^oWfR0s&dz@$p zjSsd>QofZJkOYydy~I`atQ^+b{SFKSs$9Hk?z63cSpnVL`@((Y5QjFia0LM3G5#rI z)Wb5UXP-=ol{)0lkn;7;jw$C5-96h*ZeH=B-eGT8=yks7(7SUno^2$mKbHUOddg0^ zhtoKbDd@WE{M74C=UNOFRZn9ybB!ijO{b(GagsuG=`FtG4NDL{u+fPMZe|El)K<#JCO&JMLVbbM8 z7y6{a!w=jaenTpO*Tn^Oa;EtRpW8L5;NP;XLN*7IRiaVf4Y%RYtv`q5)JOn%U4Z`| zSaf3AB1r&mD?s38u3z(L1I{XjG0}{4insZs1v)1|1(JaN1^bapZ+x=rYXt}Dtp)%-jF9!FP5Bqx?T%{-`7MruKe~& z?S0huPX_2V;zi9w8gJqN+19>Q=btv64_#zitp;2;GsW=1<2-pa8Gx>xOdx&{B9ku3 zp;sYyFsd+7M-(`oS{+6_*DbvG=r@R(BVSRM=ZhyeR<2b2miy(3La!ai*YnlC^v4lU zq!#AiF{&xTT+o0Z3FmAfZv)=tIhTDmt(A|@z9R2KLq=Av{`G+Og0I3{`NzGC$r2?# ze>v7611q#>jcYHH1jXVx?YRV?(Gkd6)BG>TP49=v-m_lb!1F(ap&2;TdG3FT>-$KM*Mt`qu6RKrWB~4eT_g@Q*EN@MyfT(FDS5jNK5-hGb|e+B32} zCOr2s7Z~zit9*<0QtykL0I{}b1#F|)LRydIcIpx!lVz)!F3U2~ZiZkDXc^*^Mu9KR0dx}}t0vtxVOf9E17hw`l!050#re(M3B+0C?j zs?S8pK%r`N8hW^&zvOVBArhk>* z|5k|p*BU&?_a`OeOku-X^FLV3SldvHMWMid1UCQXHve1N1)0D9zmLHGm%&W`e%}A< z(MtcnJQ@pK2}-G+p^TbMtF!3Nfo8%0SW4>GaDax4LL6Y#65+byP?1y?$qlSNS144& zLUZHeig~(OS|Z6FEy!kbH)eD0Se-90X(1KdGHW7uuRj{tHksbTUR8ClG}Vc?oC};Y zuU`R2>tFq>m?t){Q#~V03kz(|7gNzyjqnF!_FrGgfpE?i_!w%LJ$6vsS?$wGnN50T zW2*#*EmhTM+1I%gf1;GqrAp`P2^yLX?3xi04@inkWrb8smIg>J62EEvEE5RM*{0+{ zFMq%Ab~DgKVdTn%&@+_AoKQzi2050>*bq=p&>$x>Wq}0Jte;pnZ5N<6`Vjt-G+FQf zj{3(xLZj4fUVB_z{PjUN3lLX-JXP|OsAe}RAfPKc(uD7HhwiQsS8+Kk`)Oo&1#@sh zS=7gq7;9KF=~^@CMM99nt?qgV&X6Bw2Cx(xm+~*a_J~8RP=k?pluQT#QGr9#-|smo z=d%}3)X~xT^_zl~tX(-MJGqvpK6+cPRmP5eB&V~cblI^x0 zZ*PP*rtxb9gka;%Tm3v*0IXUA|4qJY*yrWT6G z?fV6nd0TDiVO=ZEJ}INneB5X*WS~FX$Y>QFsZm^EXS(DrKxAOuDYR(`t2}7w4>{o+ z*kT)eaHF`C9#CgW+Xf_IZ5IGlxk~uOFJ*Ton>RT%;DGMq_L@YSZJ?{~9G6f(cc3s8 z|F&b!z_*##>q6udGIY82oAKwrCsk5@JkUnNY`Kuk4+hcXjC&QpXnj6YuRYziGTG`A z%iU-S*OUi3Toy;RUDjtquyB$3y-bgG#LO3H%%v$^TBve26Mgp?p`Mv-IiVk%Ho1aP zSoB$n*YgJ(Y-$x9Hdcb=Gj`iD%eFo8vfw&L=mpN_!ZF&c*=(&e+4MptgNmbG=te;k z*Y>Y-%rSTA${w!QguF-an>s5_)$?mPwF~jgjxa}2+h0%g#)lj9ps(~9sh$P0O)7R$ z3#0G}6_ztYi4=OhWS-&H(X7J($HRXrGl+&$(cKt2!CnI1|hET|n*fAq78+KXXWU6HKg>US-C6fbtkGG@ENnt?_}DCU-Y6 z#8IT3Q2~1MPG()Y$nWLB6@+~X6vegcIWu|^dDSX#KcPK1vh#$}=m0YefuIThSfd*X z!Q?S`N$H&O?X=r(xniH2?FLFryYHUjbx`btb5jKJO9wtBeu6tP2_7=OgPBC@WQLRG zIYAN6CpnIGQv~^E!5oocKQzNeC=6Z*HmcHSzmmCM-e&nE8f)V+6V(}wOwoHB_RX4a zD5R*rMGb|ZM@5Ec87t@70XH9n1KJO4XF_FLT41s>oe|Fvvn+6yZ-auEACkrjq<@-E zYxYZmnyNBf(0B)ekq9G3jVu@q1z&Yt1Yw=y7?X5@|ncx$+xd;`AZOaPdCj~3DFhyR= za=8286Z5tVp8IKYuw)BURQsi{tIaaJ)6<)qN4`O}nKO2GROzpbNe(P8ZEMa%Ye+>i zdxdrdhtSLV`i_XpLx!s4`*&eCg9loeLTy8@UEN84>_DUn@|^j+1zqdAMFriUW;s<~ zdX7T*VE}$mhJ0YguB+mX>WvpuzozD%XpQ9zEYE_J9y%<8T2{0#*zdDht@MdPJ7<$K z(3VYSNRK$BwEL4h2aPB#z24y>MlsJ>T~-|lL9%CCRJD6JI-cjri|uU>gfWNqM^VXL zG{A6ne7=2Qf?}GyNS)Rd(&8EIcjH}?atc0F`Yt}A2ayvV$Pc%B0JE5 zS3MMZxWX0q9+op>*Iw~}ZDy|*`UOwWwtmxzSj%!ObI6N3VFR>&s@p|>UoObsQ{)uQ zQdR(jHgo@E?xZ@rRHgWhdmL2ufT)-SygaISQr0nd2iqhj?TRRIMd-a}^}>x=FJAE3 z!XrRvP^0$xw>gvt|HM{eZ39g8#12DR&^H&_QP^eu{xk>C!!xT#t#Y@Ze2ryon;wGE z*#0)7kO5-m!Rl@p$%}&V>2nFoAf_VKgFMKeRq07)cYOL@LJ-A|8ZkCuml`YbYt;H- zCi#qHNA2jc4gH7Ha$urnN(lH`Z&tv%t0d%8W%i`cy^~eOA^D>Ytg6JtxYgRjjfdlc z7gLxla?Y9`mM92F5Z%X`E3J2HRfiVexR46I?FHr>9czmE4B~g5v)ta03wpwz@sNC%d~iJ9ZiWrIq)WOE45rmn zj$CBi#FZHrS8U$GO`MLwRIE?N!E>m`=0APTzN(as^FwDwK2PV|| z@_ry>^m08;W0&v`xYS`uhe+Un8IqeVO*3AyBWV6s5d%1Tbj((^R_sk}W+-{J7QfV- z&wOI)i2dZB;%bpLs&pX3L5g86l-9 zQp&28RkC%yy60L@PR1KHoU2P3ZRvDByU*dt9pv~K?D3z-8g#z0H8B(AlZe`;Y$wkwUp0&A|M<%=~vTc@bo26O0hyx4j**`(; z%`BPkC{4o>9-3^fX@9YLeaeQk;-n0D4-YEU3e$-BQg`1YZrf;3>>lGVRdz$rSN9gj z+UzlwKQvdwufcBSh}b?CI6Zz-MwCq$iA_q3MW*Km08h)F3UY#O&5G44eiH$T4LP1~3{Yl%IrLZ)Z*^Fp?JPoN1o% z^#%KW!vq8zh8R&j=x%$oL7NO9B!&AMk52mPnbNm$CQ=nVoc(JA?qe*j)MpY-RI<+&Wi7?fmY*U=BR zy4>~d%ablzTTA`iBg>ARtem5w^PsY^ua4Y4Xz3=)&Hi|}q70e>cX)`B_MG@P4y za&U)=rW8oo3KaRiqFcz=PD0I6peG4raazS}$kjgPSW}u*KYi*pk%5su;uRK{zdzB_ zjC0C;8+8t&#IP}S1GLUc3+_o4p7i!4er#!eN^Ici?DF0skt$xBOuqV>Kj+ZUW~;-( zh3%syw0Q@fM`vuEcdsH=;mO%w)b_$F?e}$7c_A$q&jSZ5d}fQT*N~g3F@EZ zJM&u94zE5Yu~mMOG?_=U3Cdl=`);tUy`25{<-z!$2Z~E#U;Q>y9)ZhS=cI{2z|eq5j-*Z5?tvuq=Uxr0rdJ5epj+QMVhIwku8rFh&t zcc?#$oG&-mD&i2h%)glQfpye9uDr$Nm3WxEZGmIcdeT9B42=2Ktn#-R)P;1z0Nr8D zy`D1K$Hn1pIg4K|5h$hCbT+q^=S?*#nJRX!d*-XkhY4QHbKS%W`zHq$9op_Q`zXaI zrz`vZcweG+7rU3N?6}e!__n9;Vq}qU;Io-IHld9oYK?v_Dk?Opr;Zj@;Kn}S50#%=$yIRJDvz|vVJz%NV%*MneXzF)qL zHR&Bc{^<2DuJLf>uj~V0;_*4 z?nRO+*OZZd?e9dGymHN6NH@zZtnCqPgL;tg=sx3`@y9q63HhY7QI~Bv_4{04lXPeY zv;M4!U}ycEdny9j!N0fGmf}zrhB;vB_uDe(_MAUM9iQLZ1s+||Vql@$(fPn*t;Mo- zkWNlj3P)VrexS3Njc?VYVtN>wSH?BwsA0KyoR->`1oII-0B4mtJ+;00;+0mH0j2bF zHe49n&tB}cbtJOenz0%Eqm0qN+rL+|CSJJnD)AV1K=j=U!VsG=BOUQf->x|DJk;9R z_ydJg;}p^hPszFzuTRe597%yS;j_PJa@Z7jW)YO8Lhk^LKR*bUNw?k1 zELhj6Pq`NLvr9zgBg^s!^K?tJQ4cNNc=;c|48ITb=S^NbtdF0PuL8cYnv@<-r;F~* z5=cdtA4G>L9js}6&A4+V$Vj1<h$9};Sm6I(IEr}Si%-W zu6WYkeaWeJ=k(8f^(~8?TNF`4nyDYOtfxkP9bNrQl4^c&f9%%brOdu~B5unKhmpoH z%~zwsuAL6$Y%zi1l1B*eDCegHqx`s5VZvE^^~y5~)MtcQH^-#T40C43l98P+BQ9)s zB_uB|G(Do8baQ-(Br88r-B+;OP`u>rpJ2b&a+oJcNHC&!J)Zg)PkTk#N7rB_&Y;W6 zdiPi;Ws`%AhcLNY0z1uF;|u_dalcMM%profV=g^4pqBz$d9UXbnL|9v{xM=dP#>AU z>=(_rd05dv@uWi3{+mZ?MbYif1&KDvqK;UZqT3|%l>jumfAR0F@y;cm8ky-}9Pxds z2R!OfzlH=Ja&=~L>bB(yIn6j_*@Ixsj=su9aHXIr^5YP`(L4D%1{Zc7efxBFD`-$) zPbtN4^!Is{b%Dfpdg;FlUSX**0Y!wxf|M*~qvEhWt$#>-23x0P720l|Yo(xcW2HlTt1pA=ek{HvETOh4cf{p^VCUM{=8|)Y7xU>5LIMsytlVZ4yXK$;uA+R)t^Vb-n zz?9v&F;rCOmauhUex2~9h|PvnqOej&&NTYfx1?rgz0xR?dHcmrH&*?+!-Tk>XHbw0 zU&aV%dAqiK&V8zq+w4E%CEe}z4X7a=)`H`4Z79x5`dbHKq2%lwYyx+AZKx!EmJl$L z8z2g&jU(MGlQSY#pZL`ck^BOi2d=%6Hx+g7>e2g8fZr-h@A@7g$z$bXW4H;iu^m!HouRx#VfW^TyPh8215iLv@M!y`F44*;ak9 zd;iC%dnEDbEgJjB zf5pEy&jPU5RCyB5xxaZt9~0&p2q~jx9jm)h4wm` zkgVS$0TEmJ=8xqB_((Kwz5K0nmTo(~TR#+h<q_TBHm&iQaB{!F%&v z52$r|7cJ?WaW$mThLZyWCWQ}FS>S8mI!~2)gyw!2Pbd|Yv^>f_0ykApx4@TchM+U5 zXCd6=q2Lc~ktgI=_X0l5=ti0vKk)NJ50vh3{}uP-Yy#t2A0iMupqk!9s~8N2jE zJ2|VmqI32`>~zT~Y|W4wohTr6a8THojCkwklC9mqkGc$=N6*M z0mhf?5@Ye=3CM*1DuC_&;3zNpMwY8N`1))DlDSN}d3uT=V0HNPkZMzRcb*7_^?2-T zRL3I2^GrmdE{2q+-*LgzPTD?}h@WOU98V2AgR()jkNnIz+SGcToDGc-4HPuKLrK2C zD{S85e_1UdHt8$*3uD=c%bejtGpfDf1M^0jyL#PMh5WtouKr2s;DoKO`F$?O4s*6G z3Rp}wV%9YK?H+5CZb{xcSEYpD7KT)zwPpqik=M|Oug%`j89L3`h10&hA%_MZnh5ma zjrP)O!Q&#Zp*Db&Xw)y0Uhif@%s;N0%KG|8_o5pfk=m*yL5u88Dtzr#7Hw4iQd;k~ zj-hqpz;Ru;FC)trza4*bW8_nWz1icQGMBEOfwlI9DGlN>N zK2!?21pZv}Zu~lx=?!kdPMm3Zoo%L~89wid-9|q4Z=Z`lzY?w3>6_I^dOb-#9dkV- zOS5q?{JqfuU901sVz97;H)*6CjTzuAWbeOMG@1HvO$i=J}_LD}~7P$pB z9CY!JJ_gLH%<Jy$v#Wcm!on2%lC3@0WC?x)Nl2Fq7HpKgLDSI`_Z-qAEM)8~p`+q^WN z{kWYI`c6Bdy`#)KAQm`BpWJm99HFfFn1KoodAb;&a+-YO9Y*mJQ$zpD58>8tZxwtW za+`i@-jA-sB0f3cP*#}##nM+tHTnL3k8bJikd%-XgdrfI0@4Ul3Ift19UGyPl%Pn5 zARW>@LP0>fb3-~u4_Mp}pYQM4+5XwCbDf>*e&5?WUU4~e8T9A?t_DTF#ByG-U6kyn)yxN#TMK3B1QjovZA zPKBMu1Ora0vRx|dvQm!7O;r42>aHrGsFwZ&AFL;uy2{s9{=pO7C$}EV#+bHd;^xJKq=s>7733PbzXwmWR!U_+r?{R-d-ZtK6 zb%4e?7u!Bij-o3wCCh~keOpHyHMSPoLMz1ie%dCFjKtZ0nOfyx@ND%wxp%AKd&J$# zAlZ~O487@tk7cSg&4vRI8>9FCnz6rQA}vfTC|Rc(Y1@Nw^WZ%iI9@#N_jlW%EpK0J z8SNWgRsj9S6P+h9oMaPskpQ zRU|_slRuM^uWK^YKJLj zW}pJyT7e6fbv4J7cXP**n(+>)UHIfH;`IhCqTu}XB$P24DRWx;{gry5+P=$BTV=5F z9?;UG-Z*-5J9k=wquhOe0#r+$ju%$>}XBu!LT=DY2qORmTN)5-MS%3Z>@dvVRqbmKXph) zPAZPVFxr=q04a3&VJQgW?dkWsLpZTc#Jj&) zef|TS&tNoQyr%(GGPZo}2x5OX1X>IEc%_3e?F&)xQ2BdZeLq)+cyCB>G_JdAd2%0E zvhtsPh+4VGC0jGJhqm8p{ZG$@;~KadflX|y0oLGv^ydReg?>)-3tBquHj>jD(Rb^Q zMIB;^)o9!KTQoBAl=e7?R8;=(V8{^o5Sz7^3C}G3y-K1~9o0N*TRdp(IANS)9euC9 zETM_~T`Rj?mg{yZ?w$CFx0eZ#M_E<}Z%=G)vW5Kw0Fhs0E-q8|NKLsydeOKy_po-D~#3&2C6zM9}CW9D*j9XqXnL_0VDtzQgb*WuR9#rFW9Eflc8`Qj;*!LtCAr+L`0E96^lUU*Bn zV4Ig5aFkq+lB8c5eERL;O8X6Y?JafoR?y-cwEBae=P;EMKTqw$9MDdZG_xOYMJK@f zzW;gB3szJ+Z+`60zgFbA-|6n{bgn!c7rjV)j&m7&8x3IR^xoX7IRJ2z8?xo^y4-7k z$W}+s1ABs>g*=_!`ESngw2?|eC|n78!V?v=GxXHnpPT}YIb#ZQ09)!UZq3+0ULP>g2+`a# zsHBpt4~Ly(+PNYHtR7N@LN<>Z3>}o%+=}CR{|*zlQjQ2;FFZQ)FmYlI52d=zR9x8h-rcHamny9uPvb1y_g9>jXWiL4Ez5g@B~XZF&-WkmSp zuw0-<+#B#%u4weqUF3VLWMscF-%|{9#~Kx{jUCV!}LbQzsvkxlW~idh-G1aTrFX^ zW=swB4dF8oS$xj-%fqS*FAh4NY+#$1&s744M-`i7*=5yfk8>at>z}SCwjk_cqr>PT2*EUx2GUB4W8)wUMkoYl!Bd>)mm*q3}stKd7i)c32 zq-)2C!IDlU=7KAO|()1u1x&Qu7vBhFQ~$6}il-I#>#I1E(+1J1X}4 zu)NFGzzv0)G73|-=Oh^Q<@RL$Ol{;8(Slj3E)MA#*`EuNzz_x%Vfx0`>6MIK{(Gf|1(-9DmKSA&@fjo?szL}Q0%dkdPmyy__9hR z#d_kR3^*YAm%Gg+U9dZQTUC?o&UR6~lIB#kwV?^*f>ESgE@ru@?<(hIcN$01LVsqnG1G=6klVc}RoL*9_|AtEa}-`XK#1mqGX*um0WZ z-;-qiz%_5`({C!{RE$a&Z&}Kf12hURPGvq@KLhuM-RmNxMx<$UAyZ^Lu1Ic(H1tzh z@8?AI_Vve7C3injSOZae$x!{m&$;kD##1rx1kYV_ikq@y1OBqM@~0=X%YJtnd@lR8G|2D@6k>x-1AxS9$=eJ1m6smX+yx?t)|D^xp(0j+6G(q1>Xm&*) z8sMJ06JA-%+u9bEaRxZOn3B;93%=`l6X#lLerQfSCs5H*bD<#IT7_%+4~ZW{2O|E+ zI3BCRdfK74f(Z=Xn7x_2(unEgf5*?X&ou;t+mP8CU)KjsNGs$iO&g7e5xKvdJx01f zKF#_{LgTccr@BEquRi>wMHyf=*fhXln*Lh71mCMzVHV#1BpOY)8_wKALzkR%LpOpA zJ;G4d`ee6h58*3mbvJTo@}2T^o?4A3sl<}YazT1dSKJ(b!VJpRV~a9ESU z`kcVBILl826}~T!@GTxS#FsdKH}+;eo*$Z#FI|kUX!p0jQ<%Un z6a2HuROY?Pw)2v-oMLCd>B=z9ODanFQ|WW7ySwYsp2rkJP289sLe?ug#+&r_V{R-r zdA&)P4<~}s7rbq-LIQa^GwwFa%J|vTJEj*N@Y6p<FN&oTfFVCku`Pzar{2|NyCygdxIVkQ!;yp z6o=sRIkDLtQ({*yrIcj;bH#lQZ1wfl{o>|%dgZijY}cJJX?Mw$yT50}g5Y3>czZ+M zJ7qSB5=mItseQXeS4KHliT14${GsE!zJ=I?8K)G42OHdq`hb@3g;tCo0V6_h9mCJg z5a{r%g3tCo*9{&#Liu0VoXIb{xh749CIn(+^{A7@{9W!L<%iK*JE4`M9=6>tJO{W7 z^F64$mZNbj8b`!(p7o*vNAzpsu?G(2tTUxDojw@sCMv|=15+i?04{0fKdATb(}JjT?cEEsJYcq z{IrH^^t@J0(_voxDa#>}^RtWPUbWTg{RK60st`YMlEi;AGr;s6Z&vR)+lXVJuBxk! zpzQ~|7<=9&D(=BZuT%Kr8Q05o06X~N0FCqMt|^|tx?jqg$^%IEd|c$lE8>{n zVl+up+tQvJMr^bv=IvD%S46XWy;^&)d-<+RsC#H5pHDjfp0Ne%d6myjd@t8uHgRf9 zN3gff3>zRRC9XYPjXo#w=`ud~*f+H;OB1ne$I^&WJx7Zqgr>%$F@|^Uu4N#nfZz6# z>}#`jGSv2l(3_@z$%&S#bI7Ya%#y0L`j>N$=50Oy&9rzF(pLVLl%QaShy$S!HPtV= zonz?`0>@0%yf>NaanS;KNS2|aJK%BBN#u+Z@?WsG>9vMra26cv#(sGJN)TwNW{mBN zwExoA_UHuCSTs9kpjq-c&;Npmz324_-Q7HESb4@Qm>+<$+yOD%HtG;ErW^F39ju>g zqTp(0?~qw@*(=Hl$-H;#BVO4+^#alh*8|$fV~e~U#ZBHI8!}qP3OLUr(Xnd$0%dv+ z!s*JGI8O|YIaP4z754oTV2m0yvnhFy?VN9xT2&1O7Tu^kn(!ooR8*1ckf-22nh)pL zDr9ZABG*exbF#}x$!!nU8~mNK(aTVYYSrOqS^ZBMG6-BVdGiXcq-yEq&Z7}JaGQ;G zV4V7<%Z>D10@EpM-cbL~ftC-wYR2WyPgdmd5`_dx@T;q%v84;`-Mcpz){u7q-aq=mc>0d~-GkHZLs+>U${D zmI>c;F|w3@dPRs3v9*L0{L$gu6bHkbZs6Dt49=4lx;~4X-v^Ibi#G5%q>dU2O4hCL zepd{1d|JJHrfw=xcECId@XE``V=*SpwrWA0t(dkoIYQ~=UyFhzImhbC$FZGFw!#n3 z*Jo5CgUJ1J+U8vYi(rF@F6yr08kOqjpG5BxP7>)e8Qx@JD@;2N-T>pAiHDWV!0;LD zW8J^&+is%q=--cV+Hq}yHD<&?=;qs2#Lkz3=ix951f`1>DgPL2+bh$j!mJCAIltl0 zIo#uE&^+L1;?81<_+$E0gtQv052Z9P*?urUXA=xn-46(cA>)P81j;`HxwVmZ@=LQy z!7P#pM^9{-vt{QT{8pP`^GV%q+c9l)dN1iF|0}JGW;0~7-O^*vxvb|?V^nkQns?4O zcD#(!DM3diXHhlksmf31OB}#*f{Sz1VCCsj0?YL?b=6gF=t}jNVB3<#80ewKtfx%e zxR<^|8xUZTMXm5<@Kv;GqgZa(oZ5#E`cBM4+Fp`Xkyx1Oo7wpHyPfb>r?1>Fj*~g$ zHMhaQ*hMW+@Z*j^O1@nxHt6W}@#n=i{bJ9&8zr@|n-Y}POxz4=tMAkLSC1@5I23F{ zfudaGg9PMvG=rvT?}rC*WFO~IGlI;=DzL(?MP3!^Iu$h^#e|2*oRtNUI~`2vE~G8{;;rb4 z1G_m>FMQ?(Lf~h#+4IbR9OMgY5fm;yQ|-m{%1p^=JZ6aL{g}QXzSgtH4-LdEWUbe~ zYgXyqAEGvf-L=9zAAJw3A%aN+Bzh;i~di~2}N@Br`r}6b2FKMxgQ9m9?yn&cFNVMJ4X$9$s{)$8$6XQ1 z_zJ|MZpwrGx4Kq|Z)ew;SnZq??Y5^9&eR4zFrP#C^AwxbvF(m)$vnuDaYl>Q@LFEe z!~UCandPMIKW0HhSobo$J)s3jRNbE46pKktb)?48fpVL>nSFRgW#`DDDO2dl?nb5#r7mynyE_DpmHuzvFe;mM# zON6xVp)Tul#1*fL<6+-^b*tzZnebtQ<@7sWbmB-yvr&OaVCHMqHQNZbpZm}E<{_pK z^q}qgn2G_^9PXXA?S?bN?7xI9ju?tpK+qdAa5b;D_fyu^S--~SCSlhNSdI|Y6seUg zYQJ?ySTWngNMoh1OwPH3-~P1$a1|csmdCh$tKhw+S^u-+HExuD`$@39ED1NqCfW;YWokGvEtP3iI^5ogTw2Ub2Y5PoG; z@laR~F2=^V82kE$?29Y!Y@vfz@=lU}*X39_c_^aV5ja~?#$AI;AYG+NoJ9qnYISYIW9CNDD*W))N+QX>y{)L(5n7(g$OCRZ zt#4HG9?7X>ke(AY@j@#J z?E4kP?r;9wUL968uuIgHW?p|?MlsgRIM1*dbI(_$O?78KUl5n600ahI@Piih3ip zGAI~tW%@5%r+2Sui!HUO$Q$?BQ9Gf@T3zSX(?0p5jn*rT@MvJg9Dlfa(B%)>uvd{# z*1pR44~tC@Pu#5kkffc8eUYx9%0D&g_%Jox<1YoR%N*7Zc$+E+Bk;Y-R_fohjh=E( z6}Wrm@E~gv={OksGs{>Nq6(UbX@S@H`+OI z{6(V`L=9r{lOWA>p38!>egW9Z*=vjBqx=ZVE7Nd#`|WR*(&{(bNlhJW5{|%a7HznK zXEyE9K;QMbTKE)pU3{F9q(C0q9ye~GP;uRsTfwF<6CCZEZ5{UN_4A{+npOWZVDm(e zK_{9`gw0YOZ~9PFBz8#_QTvd$t!i+Y0j(aNdCn@%-*;|nM&J_-MPn91r@LM~c6<$a zVw+Db{P`&u1({s{+85ImQ3KqzYH|R&=hKE#KFJQuYO5wjA7=hF>(eNfM%s8`(COvi z^7n@@jL>?BGI?!jz6L7q)h{mVT|;!f4de@p&J~tqn?6R{_OOX$-{4dU2OD8}nVg-h z)fz6LMYT&B9COr7z$u#Z&delryU7#IZhFjyb@IsUeP`ZU!{Lj&qC=VN%3D_E0aMyf zwQ*O_D!5~gFM+<-doVjLsOu#0!M(G^)m#X$n)a?R7u`8=v-@CNlHUx!M2&qhb5ir zABV>fpTnjg?XC-Zzga!^Xbc)>VY+*xl;OW5--iIKQf<4EfX!5s?!fQ5OLn zN4dm*W8IGCLROCU{634Ydvi)5n(e7!ep$dRRAH^TK(z!dsP*M}jjx1xSuke<)A7&n zC8`}m40WaDhS#2E>2lz_zV^ z;Od&gZiy;i?x?y;7|`XpBDAS{1+Min{8`AMBbysp6QtL|b*$_y`&OMkY38VMYO}*X zXpKqaH0uygR&l=914Fb6!hKD9W1;hGbP4|MZ`7ys(JOpT7KH621KDSf*fE(afeZwqXEFd#chfO*RHnhQ+W-~|dt`m!r4N$oAL zb>%xQZ6^$G)wYXFz2lNI%K1(r*xpy7TIpgH@TpMXm%h8|N-1FL{IyVL5Xi>mJ1@U# z-A%18i8(I1C0bS)g!i3tqphuZ#;Q?^!2TDbx@uy1Sf}^Sg`zJEe>En9@N%zX0Y029kLI$3y}Jk8~#-hGRhiyv-^>^P_cg>V=9JNC$;#_ zgG~k1uQ?c{c zf20?K1B1t_8Jr>b!hc#$^L&)eXqN_u)Jij1z+tNZJPgN#zAic5S2(g!1ZuD$a-LMF z=vac~_l<|ms=&|QXTmOpVUVmJo0PiM2MDhz1#&NsEpE;;!r`0!(uTBFh^c%8sn~$N z(HUFC@P$5@Ho2#b?^1sIeq*$UO!L}q_?%*FodPbR=ULf#qYVk>bE)Y5je=6r+f-ar zc!;Rq*aiGtPa_DnRTsJ6-U|ccTF!j90(vef?g)H;7W)<78WX=2K>k`Sz#X*t>$XYc z28eeMfFHXDw}M!eZYFRPDc&6}dpBM0eD=)vdwPxqXWl$g=Bey|@(1k&+FDqQrHNe3 z`%>}CGHS?UNZiFdzWLyxOX@8Iemj2~O$T`aPFb}x~W?!4~Ws4P|UHXI%m5sw|uJrSR zz7NAt<1o|h>uA`va-&t$8_evR@YvA`)U;xWFa6b(K*M5yEimGkZd-20-qoxX=ux_u z*M7De9B>c`Pt*Nu7shV&XCO5AYBFVM<`zxP2qLGvC5}>oble0I&*zgfxj(YipeKDq zCaSdUnn}Pl7CBRY7>E@d@Fls5 zo}2}KWU=w6ODJROsKx{0fre@|Lq@Axr**?rBc~yul$#4n?rKS(a=j=!kuHK!^(vuF zcO&)A8Wx)5nC&sKmuGIgqAcp46n{{vm7&y$)T8)!ml|DOVh&eszGbPB@`!ys%@!J! zObPkH9_R_1@}J02QdF!9-mpy-YsO8I$5K%xh`0h>!^E9dOi4*A)CkiWM$`oWg#Ny? z2#M68P~MSMDiBUKUj0?v?Xn%KV3)eXcBXIBO9rzBVSTkTRc07gmuHWMRKIV!w}v&a zuz)1WpS5zEm%Kyh;QfC1v}RJKOxNt3JK;-_!YPx8O9=9MUgEV5qmuPl=u4BDV{L?t zyg0<>_jMq$eHAW4*ij51BQcw7yuME|)2h@8b{iF`eu?RDfu%wFq2;jo;K=*V4h>?l zFjwB)gt&hf7nGTH$TT7y9@8#2ZTJ1vmw)8Qc$g>SQvdwi>Xh$70v~by_h*7XgmE+q z%vnSbJ~y#=^1l99HJZ2rGYa(d?eXt>4vk^rSH6_Eyo1fx>@7`;gykl75nP9vnW+CP zWIcSz#vJ^jQ?H%}k7AZFxjvAS%xC?cROi{X(E(oUz%FE75o^?cNi_u*g|C{{*CyVw zvGSepx9l`$v@un81&9y-t^zCHxGepOZO?jV8qGp?B-3(g8u{Ea-A3A78TiR#t0b$C za8xMQ6IYBxHE`?cCsZPH0^1;QAIL7oZ0n4Ye1Dw1jm2?KSdt~sn3h+PRctYtuGiHH zOav!%dAm`MEQjZyWP7@ue^RQ|AHpo{&uiDs@&b>&dKJ`vzm}iqOXJ0O^yU^R0+EY4 zWRggD1C_y(5iul=C~f#n)8S=4g+Z2(9B_K2sMp0v$ItyYPAfrjaOGe^p1i$785jT} zAB$?Y8E)=sEXTib_=9??lf#hJKz;C@CWeh~ImK=gZIjIlL52T;4V_i5iBoMjHnPq9 z`DGpP$^(9{AV{^#GPB4stsujfHcT0vEl4@;ac1K}KK-!n z%a&3+LcxbO!udBe{?0WC3s`hFKIy|WG3>F#orb1?-!~+NPwq{j)DC3DPgGBTJTL22 znFdgWC$G2& zmR~2Xq=;XouE3AwoaaVo&VSN(>=LGe3ycBc9_Y|lbJ8HTApuqx57W@Qm&j6mvZ!=W z>}kz5*b`0>!#UZ2?NAh@CuaB3po%2T=Xg%ZLmKa_M)vF(2Ms?52UMP^=a|N89;$R> zKR2aL&xg6AE;k05<-f&7{O>W4Iz|L$ViqzB?6fDO@#f{FSL0h^dB2LL@nuYZ=*DW? zZIG@>d85%^eX9!ghX?ToKxYjXv|(}>ITrpVXCd)Nn!14%x;5hdAF7T8R?<8E;?L@#@xf3V>(cjG%ZO0&VhlfaRPS= zV2(A2qRueE{+}$!}G40y_~1#Gyz(1 ztF**m;s>?pC_+WQipVIAID}UT8CiPxp9IP9Ui-VNq@Y3?{Mew%1J>=ccMYTLuJrPO|%q`}o+dJr1uMam)QAd`24*NQO3a z*7C-GV_jOS_E7_7?tsCN@^~vwRp+VV0OOJakN-m|iKJU@xKZ?;j z*0aw)&VL1^CZ8h3yFWr}M7(`y8?N>$((^6$8&7nMFo@$sjer3JX1MS5m8amAJNLv5 zUpl>!nLq}i1nm^w$!IsUqq*bMts?a%Z~*_ALJ#x{egmF34?8qqN?@Q3h3JnNY4X}0 z@iXAL|A>Tu^0#4_fv__rZ>=H-TKkw;h*?resf&o_5itva*K2~niwa}_9)j@IB#NsY z2fFv?8fF#<#%|+m5iUZy!cM&RM1Z|LX zS98xO`&A|XD(zR+ZmDln#2G$uAm^HDC);B@=fNC7TL1~(@BL@HdPQo?$N5%Gd##&5 zZzs;_vjC=7s+P+)zLzo7OA6fsQpYEG$h3`!TBXIvQs5cii}-cFzMyqK97K3KJjWMf zC;IcNyeCE-xPW#0=syf+quIR!yF-o;%@ob2#pCY2l_YPL*LjHdwQ1(EMWoM`IqE$x zfy$cKLqfk=JGOB_o#NwH?4*6kCASe~ilM>NRwjIa(WbVGU?Y6KL+U`2!z*K!ADi;Q zhki>uTAmFvIJQ+RN1ZDOab4P|~j6&ekxU_;oaI^2ri#|C$;jZqQu zyuaOGhb9J}qW*(RHv(kBR~^7XBsbQl==5*&L_rXrkP5GR*+$|Hh@T?vjxc~v!$YhA zrArii12f0CsDC7z+c6hH;6;4qi_!VS5<=?yQGa@{hn&k{v^OA;txxgw?Jf~2S1d{) z(xlKwGBQ=Rr~LVmjyMqO?B%6>k( zK?lJupQo$jr1coDFo>HjIBB=AUwL7IGc91UMPa*ma0Am&ikm!bXDKsxia_R zQ7PGmmTv{FA_cSqN^G8a>yH5mSJ6qjzhQxt(OdBBeCI`y_R6uDUqQkZb)w>6v*qRa zUT(LPJmwtUHrG=z0C@#^X!9%S)c1fT2w`>6laxzwc3jg7XYxQZP$ui0< z%S+RsN8R}UE(I=KCjb-%#I2-ZGPe=O0*PRom8!Ymz^;*MeRruQ=AfN1eU{Y&QsW`w zIAFWI-;@5g*;PeJloGLAC3#+e3GbD9vrsZ7Q3)-BnrM$p9jsd0enH-wBz3E*yjT{w>E zb^cZu0bJ(?%-d%twiF>RRu$H0Dd6cg97M~6Vw%3?u(h-Ra*(k3f6J??iNOH@cku{v zxkt}1;wMf4@DG~2p+hY!=C@(z^D_{G&HK=G56ON!GS#=nNg!B5WhWt^Z8~9m1v1;1 z+7{P=B|ALPjC@AmMNwB8sTD*m6_E zi_ziqTo^#en?zRJf}A!@e0{$*L!U16V*2w1uql z5g;a>=lIm&YNQY5rN_3n{|mF8=O1nZ;qaVFUvjpE2K3WrD*vH-cRnFEK9t6%CpXL> zUb!E`h>fwiM6~F`)#hxgym;_lb^a&+Kswuuh@upAfcSw|?% zBZcKOR4UG*se0j=(BU-0g~yCf8v^{kU=J4iU`eA606?x+Rlo>v@9o@T47$O~qB9(Z zt8o?5=f(xf*@d>EzD(P~B57`&jMe{FopOX45?(LeQE?y_^cI0r!CzC?sMz4E@PYg1 zG^e!BSAH@N(*LoekM1Yn`R3q;7fL&75P1XW!eHl2W6OZS!7kmItuczIauZjAau)ls z{R7oHuF^;iT^|&}Q!W)WE1O)}LmOxzy=W|V`9_W~7WaHEG_T`+r{Z=U!0W|y8(*13WU{QB@Z=qm(|$$)t%YyytGxzu0UDJim$CeC<<;EVN*w zwYVSK8M2r3@-CV)g$74(NaDNMi>gKpa%MMQMlO+GXizgOzY{F(rb2%hbCWx+!4Z9&*@ znSA~W(Y3T6&0*7d30H8Nz4LhDmn_;Kvz(&5HlKl+KcQC+=IbE>pc8tDU5iSh!@{yU z;G#!b2WcX+cjp&Mau)zz%nMDEc4s@t(lexJJc^=E<+_~sk?8*i!8H1*I=A_`Gu^{a zHC7Ti|90ftXV3c$ZG3%E{>gqC#r%T#BDM83J(o~6S*sj>XxHSkU#$?tL8Vph{-4B@ zdoY#ioQTl1aRHLo++a|RpuMT!Xa4gR<~bTeDrXK!c-Vzo zx0XQxx`nqI=v~S1yKBw6J9}4=fQ{Xpni}+(l*YOW0t|Vq$~2D_PDFNf<4<~j`CC=X zV>wueR_Hu#q$sX?ung5@E!am*8@z9#({Ul8J`^zB&`KIm4r;r4tIGD$g`jws)mw4L z|ECwFe3bJmq1(wh)ToaJ#ErG}4cj{CY9-R2sT0|Yg`jLN^?URiVif}pNllw#ZZeMj z8a+IVANv1yuQNXc7Zmm8Co@geSY!amS^GDK%PT0XgT4w3JZ*gr2jc($Z@Kp`$_<>0 z9HDdTkBU2J&7`NI3n}U!Y9f4=ee)ye>R(5Plw1d>e^PJd)a`wsaUV$NqThom{dVjqFB(65Fj6!}P12Dop)d8=Ww{-FZ%fl$fYL>d7qCH5snL zg4dHZOP7CE4hehP!$v|U`<=ftK+GvCPc{{NH^~!7!TZlO(?ar(YKrS z*SQMXPxO)5B1?f7b5PqzsGsC4!?4zr8Bbbxn$o8Kkt$Th*qT-U1J7eYs ze*G}=kl@X%l@`ug>RN^UjoG6g)pYmj&!2le>ctFqo9P743DpsxTRp3-wfFZ~|23t} z-Z205(eXV~#(N1(pMeZy=Wuegm05Dnfb_l_nA&2fGt-@ypnE=Gs0-$U<4-r7#{{3S z$(Fb^ylOH?aYE)K+PMMZu!wa zKdP|!a@IK_D5$f|hJnqxBdz48WRrjUJ4!Ch^Q``WwL=3Xe9y>p|Ke&u_tp0k=&0Gd z-w*jA*MQJk$`dw4_XOH$L&U`BtK&3DBl)N6@60l4c|4F}rx){IfDeJ0le6Wrdi~5t z3fZFeFW$Gzt2L)Y@*5{Dra4W_L2I$4^J+}{{2M}$8_YyKsyWM3kc<1t*7p?A zJv!O#5g&ZDKF@01yndg`d5Mp==+hL@Ty1<9Cll&m{|1F1a!lPbNa=0;v4A{sCik3e zTAJ5^HtyzWbw=j!1I-ru+&5>#mKT}5KtWrf%)@wr@YiCy%i z)hPp*obTSKl?os~plxE9JM=B)vPHk^n0NYd8L&wFPw&qS50fnpeHd5FGy&+3E$wmv zHv{aQ7YWzBc%VB`JhxU~2}*=Ap9nv02^=2txyE9DJ$4CFFpPSX15NdN+WaS!yFk#V zIZ(a0U@P}j1VFV7BbR&#&W}3G1WG21<#d$4`{KqX<&}%d|Bj4-H!d8hmsq6SXph7@ znIF;oT_@qrG`WhgKVn{oFD#c5yVH4eeKfk0^!9zDSyp`IUNd*N@+Rko+t_NvUwm$J zzmzha=Kk0{?;@h;a)XW{D@vheApnBEG(1d&LL)VX#9EKq76=YRt<=qaDqa%T zfapqPUkb-Bn{Qi1h8J@~_<_cwTP3uV@NKlt-B~9yuIFu$l%gqNNQA|Mdzk|#L&R-b zWi{62Ey@#AU~9kUWonOFmpr;xNv4Hqfc={=ka|(RqxARTNg&0Cr)l9C6F$G1@^ z!Mj9$i|X>~3#cW>jlG?GN?}O1XAq?XboHw#(3lX~Bmt+M{TKQR07HvO787@_^=x^# z3oVEy9)%7zy52Bh*XVzgIh2!%!-%F%cLH0i&svgpYIBzka!T4~2^op>SIP-!+MW={ z;|hy3of8ZBoIW6))n3bLIgZwZmukDqH$2)9fn*h z+0^cSrQ^~xtPlT%vpP3i4_+zViF(SEuQKSSd!sUW-_Yeh@WHB|12dbuT0e2rzT|7| z#rD1D6ZIT#1<832Q0JYWChc_W!&)BWGIgF-Qihd(^mho-q49#Xl-?!vO@RHa4h3W5 z{3PSUnScL+o*Abo-5P-!+uiJ3omp`3TA1i`V|5;n)!4bfIv-A1uAc+j>dhYO(@5&1 zlhe7}u|iDkwaQ1uH1B)PKfZS`YUafRDqDWc6Gu~aD~?Ln?HXtWRmoSJdkdN;zI3? zuA{{p{Em_(AUNWd12<-t7u%@=0SEBH;cZ)AAba4@H$)y6C6GN;*qa}XVZmbEBbc%u zo9!#`QvwnY`E8;l6U1H+X?1Q|q$57Z--x<4JhYK|J7=;~=*5QCe%R@Knu0+ZGywjkv`;!SY z>tg5rxzDmXlJKf`z?Fxu=a7xa*s8%TVy|zgVf4fnjWD-%W~z9NR1IgbTuo;8G&nGbIauGe^j#|Kkb={Pm{|%c~0v z99MU%F%MmS;LsV|WffNrUh-%u)^?e#e1{cloEOs+x0>Yp`k3fjLKUeM-(oHWI+*ga z`R7LDWDvt|zBS{)kd=%$Q7}_h8`t(r++?0)Ce22bT1f5=^aDQs{FH`>n#N8ho3ByK zRt0CE@1Vn>%4t>M2Ls}*F@c-z(U(25{@rMdIFg=7ZH1b!`{xNz@8@J+ir+J}u=TfA zCH!9MD1d=3?GNBquA4ps#dhyl?TxlW+pQtVjhPhQ`C;$zRJFrQuw9pp$ftpszYTDO zy&u?~p(qVpH5uXKaI@|uc;*A}XTA%TtQ|KgoZ0A^ZRoARo;Y`rWbIS`L|k09@1W{B zc>2KgI@jvDho5&^WoDF&$}&uC7NDJap7mB@WhVp(pfs53I(kk9JL{2{9M#c zh1$q1Bsa8Sq>FUuy}^*RPq$TI__J5R`cmHSfW-m*ttY%KZ=gGZ@vF2I;pP|p0y*y- zL}Na8s>v0U?)>yGo#!g(&0Y)!b{acpFkBoN$NtE;PRc(k6NtQlwLnzo!3kIZOMQb4 zTs1w0JU|2C*i-zJ$Y)tncIwNiq>m@@50~%_m1_UWZQ(`QmV;H(r?|e>%xw1#L?(}; zbT4@AQLZ^9P~)JzePmpEtxot?Dc*uLH@-M-sf-!`-{EgK*AyuJAC3tKwWBZ)e}A#3 zt?&2j=4V-?N<1=8jT$TScq2qbgvsZI_*48#;qjvsv$Enh%Jhdled#g1+eLT_&t#1|h6B~zLo?D}YXSp=-?)Fco`GzS|;p0zW zPcysq5devacivBCME|p?2_v=(92kcUt8&(%835S5ubI?Ej$mXziS*8R4gIM>Iy=1K zU*I%+MFXuy9-QZS8Ehyqj_WfO##}!vTFbX*LV63h`2_GtiGR=>w)gA0%&id~ZbAZA z;JBh<;#`lnco<;%+M64|IoJsN7LgI;AqJO}aw-6_pNthQmztfAXpy+j8wbcR=U*=) z)mA%6584Ud&z>sOZF99-FIk?y0RX#V<7Qi+d_+ZiKM8U}A#F1O&d~80Y>!q6<4(tj zDWJ{|(zbyCkXFLxR~Qe5AMld?YkXa% z=(upcgtDKZPhh9)TmLeWMMLw;LyAoIxp65@3usz2mU^@y(DxmQq&E?x*u;JqC)Hzg zsr?G~#`zKl&7EwQ%$GD_E`-&fYmJC`Ir*PJlw{{dWRPT$c3>P8oeg)}zQku;Z4zTr zk%Gu@>aOI2a`%|7{C+d=6x4k@JqUXZQrz70_Sk&XxCYon5ZIlI3g{rEdVbhcLQdwe zp@ic3jX}34bQ8>Ua%JBBQm4T+IQz*Rx{n!FsZpubY7)l^^KQ7aK2S=|m4AHEN#LQR zn|Li)Sg&aLawRM}>kN zp8mXMfx9X@cH?2Mv}J>Et|l>KIddk#>hEoL;N!?t=X8hqvbg!YmcCd@yk>Jhv4{_> zb7t9NtceYIGDHo-McW7X`<6{ZS*p#GeQ?|=jg$@JL!^E1*QM)4*=Ho>ca8udJ=@EH zixRA(Y|@CTUvtywNG-r5o546wE4NUNU3HUE=($_GM=WW3+fJQ(yY*tPd}&%Q)GD6t zJtbQw?rAziZ4}xK&Uo%r^dMjFhvG|?iv?If=z~X(^Q8Ei;_743-9GPTY*YZFjIt7L z7%~OFgw&?)aeaH&l3p5ciO_<<}|g8TpzzhURb{* zmQ{S)^m&!CiPh^(^YCUb_upf1TIy4@UHs1ekI0MNFUNN(74&jvpZVINLB~qt2IVyv z)493lUY2{Nw+zbe3^VzVG`V3@K?TX^>D_KtT{jDJ3DP zAP$gjMH)6jNofI*lx~pj0V*jC(u{5xJs8{WKcDaK|KxtM`_W$4&YkCZ9mny;8-VTb z8`)Radps)QCBJr?6`jm^s^CGN7;6jn{QxdG^F`-m1Gg6V{F2vk`TT(qzYTX^RTu8Z z^Ox_mr9T`LTB`g|70PuXfAYA zf0rNHh|WIK#88sQE*60^WoPltJxA;H0Wnoa4@OaL-+2Z3oh|q}o{jd4lfIsIZO3=s zDzT}-083pOB_^G^i0{YiKM5E&nC}J_x8Z^FJxY0I)_9s4^G7{SO9qirOuM8f_Fk_> z8_oSILkjUHKm6T>(C3f#aquV5y{Xgv7`#O$r^Bfa6%DAEIf$UJh)b`Z`4oCxbUL;Gh#&X`sih=su_i`zV)428nAzN>c zF^I)0Ji;p;Hc?eSi{nJ!jU$VpXZVWxA$^@b_y-R5=|>f`Cu_dMtVh1TzT@B&SWI3B zayJ-7xwn!J;3=`^-3s#Ti5}+j-H*_M{xy4#Ud}jN$(aT>ms?vfSEb%S@1{{v+r-eb z7?Mc3FLehY?9wN2M4P;JTTG|^3`;bZq3I|lleiVBFV~FoJflNl&zcnz1vN{t)O*yK z%nlK-`G0>wA=nnQH?dydZQKs3f`4)$QsbmF3Jx!He*qpv_a zA<&{R;TEnU{H!{^#K;(Pa%6P7y9gaO)HF$xKe&3MMmY|`L+4<1a5TOrt-+${_O7mn z(PLsFsogq?`_o6)ERAv9;H7DftBIBIgMznWZ9bDO++o#9BujiD zz%~6I*^WN5g>T{U$1R6#)a!Whq=bZ@xungNsisOIW9317@m*f;u7YFAfudsj@YfnD zZJwS{Vulr7Dx(hrD8ycT2;ztIPro#a$xf{2qH@@|sLCO??}H-SAwrUp?5sW!H~rW2 zZs~QwP0SRlv-8_bTI7s0BJ#};u5e*d98xFTZOy`k=dh_Y7p2Tm0CPUgzH<#B;8S`>&`BnyvHxY!AI~r3UaQPuEog1o7QKys?d5d@Al1ke_`C~Azfv4wrP<(>0}`sA`$DXO>7pN9Xs~ZIS#Nc=l9K z%r1|L`QarsYkE@EwPK7E)YJbw<8l8q+Bc-2A(d1OIDo7NP@|FRkaJ)+F~{=Ia76sa zc5B4SLWir+RbDjwn=3r&y5A$B+o`Djr6M&ykl8IjV*>mek8xfxm=eJw6+9ps&7U{>mLFJ%hN+3Ho)?)i^1{(<2-dsVQ+ z$VbecV#AipDV&r}W9*=jL?3fdW`etjMUFzBs6+7WiCCcm>>r3X^w}_yD;YB|t=-#dWD6WSe1Du9 zdcAEDI3M`y`Bs)&D~h`0UwB3Dhn&zf(FI30tq zoi$Cl9V~&yRvq(2wQkb)HHL*-JBBWk2% z?*(Y{fWE3vJ*vp@uA|F}qZfnx2UCXv=yUJC5dUI$Q#(xkkcm_7hYV1(^{HvAthTWX zdg7b-N{^pZ`Ima4@j8Uh(zeOj?{IQEjVPoxAdxjgz4?((dx)5>O9q75|0>WE&$uaF z;7a&|cZiR-&|cU0qf+Qz1MJH^34`b&G9f62k6_;8@RBL0+Zeg=uDA z+&wNUgEFvkR^xo`H6g?pKuSoVt*DchX`WT5yv!vMD%zd&!vMAvi>+cHV)F zNU2jI%UdAxCJ~DDpLHGzjP4=NDW0NO#H9(E5hqy0@I;rIocf&n1RknQC%p2qu>Jh` zngDaFrStNYOYZm4z0Sc@2^Tv|f`~7G%s*Yq>+YRliZ7S!+-joWQmFd}E0z6x)j^-{ zK+5ttfG^Fo3eoj0U&va`H?%9&_rucyzeLUP-+yq(0Lh1->Svc^($gEV>j$C9G6vV=eHvI(XRRh=esA ztFu4%PwIN&{a>C@vUmM4D6 zpQ1=VP)d=ELke4ok#etj8|K~nR{y4f)B^j-sGi4D;c>A1|H%}_1vEl^t$ zF|?;}BYsfPES|iF%v^h)_xIY0C@}ANaI#3LouZ4R>cQbP=6BBOtUqex4Gq@?!_(a# zr7KA2fq^pI3t8?7puNqDE^t{Rl5qTSqN8NsD5T`;`<{rRecMC}<>RBFvsJ1Md^Q09 z{z_+bbL;qH_EErcAE#8n(P2#fL|gttf(M7l&iq66%Jr&4nz6W4@aFSYHAvtS{8i5| ztzSrDNGfoyq146H^f~X|o5ys@OW!6(29Is86b{{FvT9C%|0x6@0^R7g>S*#PD$7Bp za2uSwnFGJzG#;=`=LiE{(rMiK_45RNnkQM&y^E@a!=L*g9)5XGqo4hkZtxPqb|7v{ zoI+SDQMC&D`gMJt4SoWG8&!;;4ykJk?l#ol&XfKJz;_EwT-|V^Jo@HWP{TS{B^-G{ z`NZ>m7J!**FnI#OtXiN7n%WnIQv&hhfS`i{sv)t`yK(94KXx|C#V%h;V{Ac2RfxMxy z|DndWBqx4fW@A-1jh1+*uUl016`8Klhlmz779GCUVv_K= z1Q9hK5aelw7=8>Gr;2(%fSGnLPKpXB`@ob5<9B>--j$E&Bf8W98UPCNwJVfe_K7bq zx%0OPSy4j99h-Q`4!wOlXgHGMp;nmKRNU42MCx6*(NA9nd(K(Tlao{?Y)hzpMOUXg zSk&P`^~=#Fvc?5>dtIJI#d}$ErD8|7XrZmps7I=1Z`$hUxk|}1R1P|h#Ac#H`c0xO z?gSNd;HnM70r}HBNJTjM7KfCIQne05$l5%uS|vEmtjJJq>I3B~6=1J^STJh%bAwoH zu9B{Hh|$#mHpiHv!1Of&UdGhgPG-~2pi0SPc8kbwD z7<*%dGgftuRUAxhZVR!9o2+GZ%*JCm=jfHJr4}=l6^(K10rUy#V}97mYPF!}SEo}X zoy2ZN^A{x_di#Z@zoN41^=eYTIQ)nq06pQN=()O2hZ&sCe&2XW{SdECi~o#>NdY^~ zRL#jR-FCCK$wTb%wr8eDPiE!OJZCI6eBok?{~FuXk^Sj*=_MXfRR13Hh3c&q;-Y;V zeXYO6d3}Wc`t3*VEf$dVobRVhPdp4tNDzdeX%p)%1p^5Y>aDA4FH!stpu=#fX@+5K`krcv; z9W@O!dS7od2NcPxY6Wv=NG{M_0^jdSmCOSWs9k(T2$bE6P$aTK!{_G4@HaO5b{6Rw zzcS^mAX?5?D#*VyOJI#8B_MchsOmoaT+vBu>@BAsbXZ_d)S3@hw(Fj1 zWR{{(9g>;f2U4+94;J@yfore0)>^ODtKlxlOe`a*!3>$BTNi4V$@@BXiKfv!CV?;6Ii$v ziO8{Ni|I~khgl7@0(wPK>~?|0xZ=J?9}+X0`~bW!bmvE0c3VmLk`o-5O-f%|;=!*o z6M`7PL?nt7<_ht~mlYD^Zs25ZX?8QuwySEQ?BY~Nu9jWSPfS%(7kO%TuKk>OUXj>z z{lVH9q6wk|;~$NhyGCTMZ{t&AjSF^g@OL`ej05u6?3ljKpI_Yg*()xb6*_K7MC7QQ zI`CC6FX?C}_1M)0bKq}l^@eSD)Rjz965YRSWAH8~I8TmCGmz!a9fxBPUvU+y)(w%Z#NV7_`|mB4_NX%EW1k;h9NQ#pu9jsbU^ zd_ch7SVi~+VeTir2@;DwnyduYl-s|vODZY+3;OkJ)dr>qW-vOLup6I`{DUJACLTBu z_M=qFa|h%*x_Ghz;L+(sIt2YjgLZfVGEX27c@K_Kt_^ORGQF|sEa?@MK4kjo#OMSm zBj!FTacaU{tarQ%&NL6)yRbPs>Q2PLkpJMzc`PbI#L4-I!7^1Uux}%H*17~(D2H6) zw7X|}!HkMKB0u&ZSO97mx>q0ETJPrJlP%l1jRQy=>w0Tdz_XvTX zt5(t>mchXuI;0lNE0>K)(8YlaYWCciK1DIz6nnpDbqI2lQPN8O?_&1Px`kvnrXA<~ zmkw)buBq5;OLgv?pj=AY2Ic_JB4%MaFKqW^AtHyPYS`&7W#ZdcZ{7tjxNQnzTMtzC z+|@nI8rTUIe!oq_6$duxUyoc(g0hPJ3?XN<98Fb-r7qdpLnyFio3-%|2-LvJbfv`d zGsr|=A$u@d=uwwDC{D{jn8|`0aup9nKmlAQUKg}VEF30m|AoO&f(9K3O|y}4bQ)9c zk{mkoV}{SM0WNMneK8!er1A>UM=}y(s{C=lM=~q2*Q<84FG*%OA{~G$C12fkUGrr~ zmGuJTx}h`D%I_ni-|zYtaYYh>^*JW2ZBDl6JJN{>~1hf@7kJzS)$M(9=Pu`r0S zUPsJdPk7e_*AF-UrIS_y^~zwgs}X#wamDdk#!#cIjcRl$V}=#C26Vgm-2~~=70LlJ zw|M2sGG->{NlttcesWtxs!=k$49;zm5L|ECb^w}|>cDaT!xQL3Nc2VRh5d-&e?TiZhBdUv(=8U#Db0lJmFg^Wrj8ybYb_X0o zO7T)YuE#Vx&2Y)&niTVcUlxK?OqiUBbZFu|BMO5!3koY&N}IIfv{S-E!o@G`EljT2pAm!%3+J-hVxLg= z*Fr}g4_pHuUdKiz4lFm=4(xOx#YCJII@|v|*6_oFOYW6vU0>*?G)m(GyJ_Ss4VpG{ zSA>qEfi)oZ?l0-xrRxJD zQjh?6WA!GF^axi$dXT~v{LN64KH+T`?kZyH5Tc$cg0SM;G0Y;3{yGZ?RKFD9O8BKJ z-mIeNq+34tM2`-hqF@BT!dItdWw=X+Qase(%nDYVv)FeZk}p z=(Q0-bRO6QY8XhW;H_vR$=SUOi-~vhAS2I978rOJk;0xYkiuCIeQ*bZKMUx#t)Qob z6MRY1)i%ZYarg~C&)~c#S8pD*(o;{_m*6<+wsyCE5>JA2x_8(-SdM!^(bLZd#TOe` zzlhO7h&%|ZF1g4d1D2infco)d#g20rLJ}<5fZ#*v;vKPD04lA+ZUlVwwEG3^*Df!5 z7TZv}homoVtL99A?b@Z~cum#_1b2!#BJ*qo(g;pndQGGTZL4O42$LizOIrA{WTxvY zOe(KmNvp0byvhW?u7~C`Os`e~zOxKW#Jb)?;5&nZ*#!AK{``X92Xc<)nckn|orMDb zX1%za&R1}JNa<}Y#|j<@>4FS=n1lldsirI?bp>y5E6Ce9|+*} zR#>Wp**EDIGbd7~f~s~gRd4^(f8R`ZTOn*?*0T}8>VK;Z30jHo(SK^OaQ#NeUtu(| z6;JWfQ5X1ht6gFJ98Jt&92$3x8SCdEAE@f^uz#rrZ=>UEz1~R2!xOzo08bJr<9s=w zMOGVIG`91(1zBKlm%d5b`{MPvpOYG5YoXnFC(mB}W@h^l1PTO6I+AALJG<*gBLRdh z^}Jd}+h}?)Z?y8A`vtDL)0*q5hj)p#1qB%ON4#9L`}D$b0g&1cGQ%!uxdx5sY53%E zFry5j(|!UjrKf71cE8Ud;-M?};o?u@I=EK|e^zRZEijibk8h~LN;L!=grM}V_`C&E zU!$)Gt)u>BgCYZ;QFC`-^XG86y1?pPwUv*Y{&}81kb#qXI}f!L+EN`vA#hib6A6xk#FSvJza*kn2^A8^HK-nAKgT>_2jkrk>evJEWCVFu zT5qCJ@&)y%)31q@*Dj#BD>%SBIpYo0(9BbxDWn@!a3^7(=DC;+{`n}$&UcrcryXFu z5jPC`g2dg3p`Jhk0az6<%i~AUKtTu0oZ=^+CSdl~dBs7D-e~Fdkg1 z-?Emj4(+uc=_mtep_jqJf_mvNgk0}Ke2`NU?e*&O!$R!iwF+}BylxJ!0U;0ZCl*y- z+Z@59^?Uv{1Urzw>|UOt`8oZ7oC9Ml*~SYZCq44U#V9`udF{bg#*si@%K!v(WR@Ku zBM^~FGv%csXy-=MM`;yt<#ac1E*6EFvnb9a>iq05wt$*<#*Q%TvzTDS{CB`F&w!`6 zc^^;J1UmUqNYeJ7rtt_CZVuUTRi?N^ot^+qbsndeS|O7x4iL7s9mHhJPqO+f&12l* z36+si$UyTy(`Ro>v5`U(zrR%JQ@$7I_Rqef zAxiA%ZO~XF{LKd6X54DV)+Un01eF6x`YSWLN!9c1o|Up&BHwxPD+Y9!*|H5?5rn0i}OfBg)hJhb@kTym7^ z6}CN&UYrC)oIO_AiY2izH-1CNraLH_S7RWVl*jpslt%u(;kLedC;aNWXcZJ_lxW#r zz*)crnjXi5*pjWi5Eumhxu!6Rhf~&gE!=m1308p~iFRI>JnUTXCq=lFzEPA%dpoQo z^)%p2!wfq2#MOR;*Ii~_4vzL?vZaDH{$8CR5R8MLFTwsgE@e=>-uV4Z^$tuAm!Y~E zd~tr_03gh&#vX-=ioV7i;by5TQ^S+RbT4)&BS}_=rHq>FwRvsrzY`WV=2i`_eP$)u z7O;j^3wC}T^g&XEiPXoH(k~v_1Vvo^CHvli*GxH;?j#wB{vos#I&E%*6VP|bY;&s& zUs%&geO+4y^Zu`;;z|2yoRr?QEq-SniTD*J6Hf-yuHS?IC%3wD5+LS3#Mqrj5Q3jY zAum5a(R+UiY(L4aywF8$C%N-hO%rGdivbxyMXCMEOi3#@&DWs0C#3xVqJc_pK?B|? zrn}M(Z*i7lxJBn3D3~90#5*=?S^LgnjEo6ET9Xjq3+A_& z!mi<41`9EZPvGtQ?=8(bX9!wnl|JV#FS+UcZDS>Or-SRN1yZ6nP)Hncwv2J-XwU^r zY^>V;dUT;(8LE&CSxfnonzwlSCStD}!4+RlHe9MT6b0+3Yl4e-c9~pW;!iyrhW&=B zuDJ<6KaL>(!ezXPf+1`VGWLMJ5?|FTOcg|j`!!)BZQk$3kx`aBj34HAKF_cAKDX@;~XwV?2Yfy2;^D!ijlxthrw9NDmPT^Z)+rwW~fOnt*+bAr+Z9`Xw7U= zkLo2D`TPp0y|C2}**gz~OuFkj`H8lB+n%Xcal*haOU(nRbP-2NR20jj2_Q~uco5r@ zH0v*I97M3%xt@5Uhqd*NtSw|@rigw2D+_T?JO!O1YaI>YW=>*80rLt3dC~RzP3>ak z_x&&a9gU5(E`R90bWT(m?lFG&DgSo4ml6(zhIh*y#l4aarHNt!7-@;VPh*a>CtlHm zLjS`7PeZap;K-F6JE{-~Du%*edRDSFBVw=@6SYVL5lt2NKL)QLT@E!x@l!=ZWujrO zk5(+YU)mCHSDQn+<@hHgP6*r$bZ!@2!Mlj!SHw%doq;diP=wLCMgRr2emqAP0}%b0 zGV-1d&q*~hH^tHz!Htg)jJ-2GA07?@V&75_(Erj zz)L`*wZ=wAe24AUVplAUOBr}er1YO3l;|l>w@2DYPYPu#w9BUBhK=e?RrTuKs?!l3 z$8f$pQ0Ww$f4A$d`C+mTvq#qr%9n`KN15do2MG<0H4ST1#P15FTeIDDSo^Trvs*m1 zV^lP6Df64R*022tg_RcB4OJa z?8TuYNWtVMLftemW;Y&4ECmEiq3vt|#(Ez%#f1zRepKAr>oiy22)M)EPm!D9JN*SOv52;Y@YxyDh;@hgh?EjU#Lknz)n@`5;D41Nm} zvr1dFkFmGkfdy&4lq8vpa+T7p@0KxHN$Nf~_amZ(MMI0P&s6xcC;R_9HNB7<->Ub@ zE^1Vq2NtE!LChKanM}~m@(qgV2-U^lxlQxzZ9V2+dd_{2;+W8XAY04qk&&B^L_YZ3 zBd)~NFDH_^c{?s!-Y|l1T&3|?xg!H}3^5~?fdP2~_zkRIjy@7kIYj-B0t|;xVVTMq z{N9o?a}$ktTcv{Ij-u$Kd1w~3V7%H$0^+JQG=BzJRs@q%dAf_|df$+MNl%YC z2^qY%X*`d_%b{Q@FN)f$D+MCc^)pKU^Zo3bA+D_AQHX_Aq*)Vlz>!Thp6L5w9i@79 z6|5af7~c8g=zhT6JM;R%K>;4mSrq%av{Y0HDl;@8VyK#@P@%ycbO!-TcCM?}SUMVk zXJk8d$}2))iHxPP?pnZnRv=2)Kum7b(3lBwd8&LP8a}zZtGesge=(7+)^c8AgMRLZ zg}uIU;^{Rzv}7+rF3IcJ3xjs9+$=Z({A9H0!jrcn5BNLVI@>M69)q;X()&klE&Z$s z4qGoAfw!Jh6!h!j4}w$-`ZI^oS5WXXYn%xhXY<0gMg$JX@b=FcWC0msx8LRjj7qPO zUnXW3DXtC-*?W1lMwV_Y0yd|{TSW~uQ6VLML;GvV*6Odo$3Q45@4MPRT(BP z%Lc~$DQq_*>z+x&&!k%vtys7kD4UgT%Yv4YC1N~$3V)DqQv;bR?GARxcdY#@cYRNj z-l8vqqA|ZNyfPPe3|!O6pJpKicQRfsC%zycmmxqp9T)@}C-`D;(P<$5?qd;w#hi+T zIXLVyIrWIN&-Md`73c>uFfFL=G08}`+tvUnRh+R-8;JqC?p%uIKaVQ1y^h*#62$mj z&kBpD-^}sJk!`0n_Oecs+{ByUsWN9)2pBz(GB1CFwZ6fy%B+=4Y91i>3 z7ng8W8|}Bd^!q$5Ys?}`w^CNxQ}y}f?QaHyPHynhhVOUislMGwp%K(1n(AR0)r$4i zD-5;yQ3v8!!w4~bqt?LZw4YsaPR)Kc^LUX9D|gq@9AygY4NVZPIvW0&yYG3n-ZtvpKu&~V=Q8Ij^ir*IqU(3)59w8eGgS#rduQS^>#SnXjx29Y@_S=A!lBO3uFhP z0RKvH%jETp4hX%m|yBj;aY5Y<$Sa^egb z@rKNf7w8o=#>Ec)&6XV5gc5&VJbA(h%xg6|%_FxXF-pr1xx^@F6r>Vz#)w3e-ifna z-O)-HUtS7}YoO&B-KAfmHCkfcq41OBdQ|qbhSdBqN)crJs4WvtI6K*py0z?^l`7%C z)C1W$J`m2akg4`p=bl};tiNqcIz@@O90}#$M+ zkI}G?2p6yHQx_(}V{pLGNA{kh;>J%wo4Z9x6`*GJWt|pKQR!*Yp*xRX%kVu&l|XtD za3Jz=xG!)hT$Ayq%gFVP;N{m@n8egXtcI^PZ#w$BYe4BA4P?H^8=ZG?^4bFI~*1hQn#`qD0rINhb-hqY-N!LkZFn4U-a^3FI%@V*)D!SxHJpSeL|$AQdv zy^6OzvOM2-q!)wui4$hiy_^X__7~6pzW3;{DDV>rf~6{?%;Dprfag^@{ss+w3SgOq znNRU-LImam*l%@2cTDmXtW{eNa)@X(rLGW1R!>AWr+_bAc-j(0-Y-mvfuai<%DJa( z&NsYCBfDAH!dW04xK2QZadg#$!1>7&h$P+wU)Xa95|MFAY7)5A5 zuQ3SVPM(=+nLGgm4q|BxL>ZWbG2M%6M-wj6km;@)NmZk*KCCV8EHzpy06x48N&45V z_*RB7RBt;Wk42MZ{S_y<=2$|4RhpcMsfYA!14TJy9d%1Y{TUgWR-8ua5LIi%bM=Jy zfuG($Nw?mc)`sVkvhVYE*nRYZ-NkRIek*6WbDNBFF$qWG6REkGvKKJ*s%(1;n>Cdq zYMHn8f|;(p~gtlaQN3yr6qaTT(b{Y3ZKIN$UIWVtL3k5q!{YYaOs& zk>LN~za(v46OrI3t{>cmi`5zj@3>@H%DaAhQ;YhN%$ zf^da0jVw)H9fM_iQw@Ns5bx&TG z@Y#kgrq<2QC`SFq_XpP7n)B3LQ1Zxkcd@Q)I|gYse%p}V$B21(0KjZ62?n{e02OB~ z;byJvc=#35j>i=H=dc)>c>rv-&uZu!(C99bYkd{1fT+}Z-*T)lLtF9$XNT=kISq}Q zT#^XDK+icd25-Dn;HM@}G(_A6#A=}Ma1|pu&nZ(&&sy`2_8aT0D{dtP{ry1|!Xo)f z@GAgJwN^ZY3fe!iAWd+%A2wbVw2Q)YFwEM~jVDJ{Ke`^xp{UB z>PYJG)n+{oCz!s|;@Dl3y=N){>uTk!OBTdqupHUbsKKh?g}Ruf*gZ(n#cySoPnHCQ zUW#N+RX~Lmp)J>0+izXG=+gE05y6u(9jD?)Msm){O|q#OQ65^4MmX82#Ln)%aW?52 zU;fU!hmCiW5Ho|~k7%pfw)gZ_eSwfuGo_mZQwBLvL;g_D3k}WKL$UlKG;q?{(lH6N za@(@l*rfII{^rQ62C!(Kia-NZ7Irqk6U&9eNau||BBNJ z+83xeM9iW^)C##K6ySv3I3vL4KAnUG32w+i(9;Xs%`Ms96n`yyCudb5@@#j<(CnEi zR=Y^w(Yx7tdrOzoOARr9T^l$B%3c(`W2Y2HGg?p}TN%#bADLKRs9+?P%W?wQGv5I^ zC7VN4$3d#xj&-$Fz4_q}4`mi8<1N+sf1N(4*1HLX-F|HzI}S2?n$0kUW+2|IE$MO+ zXn10X+6D=iRVClqXtK}y( z3Zm~3pr;5nx8CRGh#y7I%%4B*+DNLqS)u?pD>or76=X4;#t=F5@nNDcjM}p@kc+2e-=&_-`{Xz4{-b@w_7dEtMOtps_nHR zy@1l64Wf5ypk*^Wsk@Lo7Ad*X;hcJkL3HNdPNHe#c=pwA{Hnw5Bm6q3rlo53 zP26?8@N3sB=f@?V@_w(JWrXbi=WNoaLI&M` zy^ZE&a*j5#OZvDbWa+r8kHE#=e@}MD0SG8c&A<3gB__OvJV#~&li@VPzW3}u)Ve8< znMMPyj_`Gv0UL^Y^HsMX;+pL?NZ9P8W`w2iTGCrdAZVrbN#HBjZ0d zz^O>gn@IX-QA_Bux7c%el^x)f?&ALR;qf0S$bC*Ywyw*{Dm=cBywYgi#P@dF^smh= zae2Bx^{v-BkUrr?&qPR@&Xn6H^TeUml(D0B$2)JlKf1Hyj)HWyYu5L+IDz)yUqRsD zzaT;o7g?GzA6$ll9LDo(p4_0~#cr82HTBeD-OyWmF7(=xZ*pK+&eH z1?k$RKN@~D4VYb(+vVszSebu>M{ljcVEnP4k2t7wokC83dmq+EAWXZ63G^h7)^q-q zW5=ZG!guM!dBXm!XWgp(SDCAQvfL*G;H1O_Vzv$D@QlbC+`QCc)rr zPQ?!rUlzX034Ly-@nc$p>C&H;fLi}xaJMPxk(0XXzc{-M%NWmOXXp?b4Ns|=yP}x)lQ%l$e9HSOpSUwQzkYAzQVi0lr zw7AjMOB>X#^t}8x*n5)x9jOPD%7nRk6Ew{rz1#JxRx_4a_x(>_2P>0$S(eqHzo82v zUH`@{0qZtpYBaiK>iZPOv9SZIGvug|)6$L{62~X@6XEnsOZNa!b5vut7%m{NzkGG@ zW0w$l7XOhbqoFN1(7l&^Q-G4ZkRuKO3$WPcu!$+6 z=`u9N;v1x8d|qgmy|3lvATrfsTDB2#8!ZiL*gOZM%JV`mGF9p7wWL4+p@YsNz!$H^ z?co^+VmAFlt)OqBK78=~@7vWS1Fd@D7jbBSmja*@>1S6sh8FoQzv`qo_;qD3i`#YZ36{fTOX;qE(bGzB+G8 zy};|nAOXLL!t)QD^uku^U>oi^8NUvCo)zC8t=P>XP`y*t&w>DEaW_ z0i?A|%D7!U9bRxr(abRl*%R^atSTS+`G+g*X^VfRA$}GMv{=Lq3N*mbC-2Jl*u0~+ zldYE0zm^A_X5%qtj~`9@-rZXPNa0LDw~VF@P8Z6ic4raT|{XJsK<6KFe-jyBFnul0gAf(bzEU>cAVp@(sZWB_-nDOcHZl1dMK zlKdi8;E2TYPvgdz%SnI*tu!V6Am5<7=vGF6dW#4=eEZ(H4&Vca82s7(?_WD>d7<>` z81~9H9O{+2_h`=&ZCUs>B8+(-Kg7L1te9_h3f^06Rl8;>FY-$eL9?A!)`QllG7rQ~vXa!Qt6 z!7s6KyU58LsC2O`2#&wTTtvU&6hK;a-(*cq5G{>62;b`j67$5***03a$z+Iaj#!A; zJE!A$l>$g<9kv%YDunRVhp58y3$Ezi_|Oltsf|iGxB2T2bOA&Dgc*ExBO5G}>y6>i zO)oZLytcqUkg5G{5f5Fbsmc1+bfk8$V4MtroXQ&kj`nmjV;t6?aOI-H#?9xZb$(2E z-w6`m=((v@=uGrkfeux?zf~V9eXf_lm-JcRLkSg_X5aVZEDszu`=aeQw5{7-q>XIhkJ6p-(|i1{S*=mmkNX%H9kx+0(OO6q zhQ;5mK>{&l3p1JCz|ishvYFbvKmtfGj#onASmIu5X8S07mN zSBfZujOgah0)Xki(MO#_fUqS@y2t6YTKep@=Z$Mkr=$A=eDh(ulA_*_(>Do9_hTsX z&C1%h;@g5tC+_zKbASEp2+~^AdbgZv^K z-nCwla>9z2K^@)APsc6@2zW-mNiy+?VatwJP=VXmqTMXy(FJcXMR<6=rtn!F>Uss+ z@!UE~L9T(yGd=qc8efV$-+M;@4F~38=$!)CJl&ko)+1c}p;uV;_%4P3B{omg?_%P?Hq|Y+wUjGo* zP5Eh*N-Fe^2JQ>))h%1|-oTxjgq-50st*V|)VYSGR2tt>cM*Wzn7Urqxc8H9TU~>p5xc@M|0dkT*}$F~fs@C@j{@v0*6e>_^Bs1rCUQ%sPjvN|S)q`6LAJt9Sa`7|udmB|xbi z!zLU4cPmRd&s!jebp&(@*xae8Dbx8NDu!+T72ud3-Ix2lY$g*VRHJ*d1zYBG6nryG ztY0qiR*jW>xQu>UiQ5=Bwe=Ga}gPg z<8@SdF@yV9h__lHan0>oKiYq@tcyn?rLFFKKen#q@xNEa{$d{A$T()V&mK3ypHSH} z#<=FBa9|-^WC|a>XNWq?Y2=$*I1DqWm`4s%Nke1#&O?=)uLdUyDs)1{OLpF0`^)%W zH{W;1_f~xp!F9NuVDz5w&q3G1Ibe&5vvHRh z3~I*zU`$)-xlv!uO|g;u2`TJDwGLl?U;8q{YvJGbI$>RLU@cX`Q3=wSHaiy1b9}B z_0;dc=FLj>24>^CfiJedGWb&2f!gHURrwmoBlwa*IcUuIeCw)d*y#+hL>J8?@5%k! zk*Ho`%CP&|&)9D)hc+IQ>dB|z;f{&q6SN{bP^iU5fUPr5h@m8nRGeYuPLC? z5=x4c2+|*_(hW;3%jR7E_uNJ!NTK(|GDND69>WPx`te{`3g!86z31$eNDQ_hX4Q)+RPsdgAKGxJGb)$BSi zr13aEfK83~LuDFPqDfR|k|qP9-?KBfiwQVu*Fdn7ivS1n?ffQ`?3BWCkZq#6y*=j9 z4(^kxd))#^csg;e2N{X$z!lH^Hm3F_#mFo|<9u*J-8wT$kjKPr{XpUH4GDo}69y%d~9_9;{cWD?#%-$EjdUOOBZ8txOsN z%}%1&b-?b{^piigp6JDo{yCmi-# zQ7o($A{nsQX@Vg7)>iii!LX+~JjGsJh*o*NenYv*HAdWVc_dy3u`&xnU*Zg=F3Pe* zm%l>+VvMfX3sxZ``Xs53oFjGy+Q@M>YfBOlyyu^vp75#<_vCkGFrXpl;WtFeN-NSW ziLZNf*F0)0F@*09&fz)IB>N|OX-g2x4W>N&BmfM!NII-MH4abv1L7={(Z4@p1NHHS zf^LAz`jy)aVh5KNRitfq>dTlGW`RH?bO)c@S&^RKm2U%}hHB4V0eEm|?r&8-R*lvW zGI;gEmr_-csfz^+x^z}f1hz)t5>}f(`v}f1?A8bn%9kyZ}D>iwOAEb<5c{t2=pUSrZ?q2!s*(f~I5ul=8IMPf}e) z$0Ex9dI~a>Nx{7$xPZ+g@`G~7c^O;Hj|Dd`=U$So)(m|hP!@4(Nv)#RTQ2{%MV{dOESg%AEJa3m zCIPAjUEAw5;ZEEs_ekkRA$AuU3Q(TX7N#HO`gBA!GjW}8J4+5-A15^(^51xQ;5r%W zt_GR18Ni4$CHKSZ-sbODGeWqYb}>~SCPwlbR=u;0SkNFeJwhEo0~e^m$I|ROxKh?J zNsS2ApAD0-Yh!jPQt?TER)qW57Gyt)4J(^T7vucp8fUr-@SyYr$4(Zw}o`wp_ib_86oE2G@^m+ zy8F1{=LWj~eJEOZXGRO4EyfN&aEN~H++WnqmcA-{>tlN zQ<%7eWhSb8S=apSjCFKnl)#S*HOq|?eLE9s?=n4@cMvzbQq1%|n}@ovj95H7G_q&NT1$d$IRQ?#IG z;xFgLcehh0EpLGVER%ycVrA#w^I#=C>QRRAA>;k5;1yPO8@dPbPBlvwPH>z+56}JE z0Tb!ZfNeFofj6C7nkS8h+pLdl3=GM;qm*VIFg#P+W#zhI>9c*b_Hs4**_)3FPnkse z-;3xQ&{dfrIb~XQvpC7;MYJT)R!RirJ3pBx!<|Cy4^7lTsBnQ6uK;ZTw@H}ep!eUpk}3NgMYVPwfmI+ zpSY((h>Ewx`@Q22#ctAsh#Lg4W~xzgOAqjcI=^u__kq(q5UVtwLzaW0US0`t2M`fU z*%nT#6~6AjN5 z#Z(A5BX)k$ciGWfs+#QaIy#wxa9wpsd%=&#Y}+sY{CXC|4s~5Ly;Wv5IWFsPmiBD& zM<-v(6&C=HYhJWnQ1Dx(m|UJ|P}9vW+v8V8sFI!1DM_sj#qPggm;N5|dMyYpxnL{2 zPyM9>=BZ>m!YZo-8^_5Qhf^29a+8{R< zd_tJHp5rGN-*okLA##|~z-U8g-v5jFXS^S&Z(C!X@GAEqdlLL)aM)8vS=fS9@QGJgIc%gGLcIJX ze*vhh(}YMkziGsSJAFzTN-5hX-h$aLC`1e$eRHa@fi*1a3d<+vT2htld)kW|l;B@B zkx4IGyNXP(A@&c!pw*LT40GS+B5(O%1)b)h_}%X|pNXCm-M{`o3B~;tk>2DEgsyr> zm+gK`h2k@*(fe&f*Pb?WPAq~^;=DG0MlW7=hSxUC zp5Z7|DiAvl}V$-=SxULn;1y9;B& zjiwztUWG9=pf*K$|CMD<*ZO~m+!hOI9fAKuj?kr3ViiAe>#m18OwdVYHiK)W$%K;g z$*aLo?z^nfV{~aVuHu=0hBK!qaJ?k{*%4+AP5RV=Lp8>0*u1mZiArU2+`6;5YU7!D z@C~U2`1yp40rlMik*Cx~qvXG@Q+h97DKyqR3nko6dH)+(kf!`3sO`)-KSv>~Si&3~ ze7cQ~Q$Bto02)}pqa@*oe#EC5j4KA(5z*nMBdAcfW7Ja7d7(}9c7FGO{JJt;(zu7! z#q@6cp^UpUiLi`(cJ;Nm_~Wnr8;PeKKTPN2!pRw&zZ`9QZT`L0dbdd&2T{`^IV-Yl z^!wCy%bw%bLsJBQFJBSgWt@~8@t*Tc%}81G`AO8u#2iF(>lZwx>CUc)KyC1@E((<_ zo#7~@C+|{kO{O1}duE(Zw7%t#H-%00-UQja4!o5~)(kXC%gFAaHP6+^GmA)%0(Cj; zrM0XJ?t1ZeLA_`%w3an78^+obApG094kt`oOQFZ8hStSmhuP2q2(Z{Fqzfr_{({q` zP-F9~mXB$Mal+MZ@Ab2{`=%_#RG(}5(@mDIUiDleJ_K>BAyf(s?KG+L;KzMf zGMcLcR;baR-Yyl&;Bw1RJYdAl#NZTT*plQ_PE->ZT^7>&s<$g~WzMBc?p!pi*KjOo z7X}HvZEmIgU&jSOgs`LT-a*1os&$9UAD1UR9Nq1Nd0JLt#U{ids}tRU;~w*oMb0eC zNs!Q>o}4{c9Wj#pbBjf`ZgY3m5+wg9EF6ppTLQX)^IQOphW&yU#c5fpOwn{5_Lj;| zyD8s1WHghCgFz0R=TWAhbBS1V)Y@buPIB0k!+E^zVTk1Qc25PEs3k>QpHW-(;&j zs9!DmnE12ZdpRWKjQhRZj(iQX zS-1(?(M+iOw>)u8 zuu{L8Ozu$C-;n7e+pXJ^@SxUwlpf_l+j%O}nyVzo1=3TNtUL34cv!YZC}N8?LpJu7 z2QW9OE1SEh%rYtYk)vSR;EmTGXMzN5i0kk;q|7*Sp(bXP=hoD?AC5XVFx(rTf-Gr2 z2Srmt@H}yyp=0L4ga&6|3Qv;)6{p;j=6`d9R3uc7`&M4pJI8MXk3A$jjY|SlnE!Y> z`t{xO?8fufH4b8eWBrFWoRzBp0H{7z}` zaL}?UejsuZ!uz4sH`cBIW_=PKs zlVq|QuaI!^Jt*+r0y@~cZzGIo%1y5h=7w$b$Ct*o!;rXW-NatOd2)0?KZf)DYT+phMxUgpq(1-D#9THmmEWEL!dzdJ73#=fcf<3$z_is^uy2 z%`Pw8uE~uxdZk9UEEVXI%PHdUL}CA6&~#lkr+yvc;O+ZMi%GbvOwJ7DbwfGG)J9`$ z-Lh}7<(4~sHTikO%kB4kVS~vyf984mOAb#gHhkM@K%IRFV)&R(F}aUkG77dkZ#ZdHMUlAXY&_FbMqX zb@+%L_W{cyzxhrK>julkO+9_>>K=ZKs~cDZ-2|5X;*^J~o*?*QgeBmBDO-C@$gev5 zYjb6Uz;I3}oQH!~Z&G*DT{2E^d*>gsql_d^AMppDx(qT^I#=g1lbUB_el-Znu<2!X z60!0Ox6o0Pg4aJCPBFZ&+rBKSwpHOYIOzxgJd5^_!0uDKH1Wp|+%uodm48^pL>QZx z9QL)eJuy6(581$!l;OLHayw>4dd@^9gjPZ5VNqXBWa-H7)ewhmisB^AetW>ztm;** z1jL5=<&aC#$>rT;S2M}Ql2o(y(uz-|sW9IIZ^MNzGH{1?Sg#2W2?B!foLhL5pnxva zK#%!Y-svq*S9yE>(29udfA=E8-JU;g_JX+Exi->$0>?PzxLoGX0KfFeG3z6?V3fYL z#{zln3?lROGxr+hHhju(s+s^<9E~b9lNHnyich2=cPOW;%U@f*U7}cT+2=ARrP);% zHMrh*4Xv0w!J+qI4tfuyw&bEiRU<%nVAmQXMv0R1zq>DX2JL}D3d)71Y{t3{S&AHE zU>xK68j)pB<}~I`+b{qbc^%T%eFc2Jiu>3J{nqP$lCR2hal)?4xn$2FId7N6w!<4> zUR+fiBlXld6JCukXnmQ2TB3R`9Ufy8xF>OlG!rh_ONP{TAjWqNeZ6>h9>ABCe>DBm zQDZ%F#{V_^Jb$llZJ+VLsYzs}pEOA}9>CiWgcX~X19N5fMww*$A2FQ!^Dlwlfx}?H zFb9gG$93O1GbG=SUi;%H=bgl-bkjE zOQ`aSRg7Hwlhf}vY>w1tfm)X@GSrfHuSY}hX_^L{EYY1}W z5g8-}f%eH*uU3(MzGL4xQJjM(`%}8Cgg4Y+R@;Wcyk?)hql&v#A3kUo$MD zyX@Qk9@6a{IU)O$!L2_zpL!NOO>rc3t{iH!`@*VKe7w!w>SEY1IJcO__f|W_ALkjE zw2kX1n4b)+|HPAA)2CA%`uzhZr3vIm|SP5zs+094H6=G%aM8BOEP1t_yw+UTDQ zL^BP&lHIEtvqG57u-2kYnN{}w0UCz|B}{8}l3xUW_`WD2CL}u}5$!9+?7Ltq(ka=&B6Ot8T6xDX1VSTcLS3jbh40Fk+kP>Q>sqp8nu4uR>1VF0Huri&CP$ zKV;o0DG)$ki#~oHP_2_``I%z583|%H+_R3Sae@+f&nLs}7_U5V&#fn}Bjnp0zw{y( z&14VsfGf(H-`D~HCLQ6U9nJiH0h6mQo_bt2aebPXOQ~kU8CY+cJJB!IM@3nf zGAeFy{P-Jj*Ua}!6vRVSUJabE1KK6nfD_->8c4q9$ez~jX-(kV+0RbsaY%Su#{T{j zs>lpva?SOdlRTgfYnuOSyGTa|s@_2cHBP+byEC4H61jQH8Xcjxim8= z+z|~YQOzlSp&?38?(y-zhfajaWb{gemW_oR*0NmoT%HscE6(KMzoQ5#?guduAd&3 zRTLM&{&-e0URTeTkE{u~8F^dZd%!XGVwy)iFCON*hLJp!#B$s3CB1Hsv}U8>#YgtU1|S7!l`-E?Twz1w2t?w zUnN|C!9@A}F4DZ)i-v;N!DHaq z%5w8^{q&31*$A(7m1K6_NJ<)%zp})IClhXVL;40Nso+psK7xQ7dV(K$6yB&CF8>HM zRuZDm2E_MR3drcK?r{-_PzV0?AW$K(kH_?r&=0ERL6`}`3IdrXZzB;P5?PB4*8TkX zJIX)1U_lsD4f>+r_laUWzv^w zM@bk*KtA=|n}WOR-XtTpNPGD~kP5}MI;EvZr&!wOk5K229&H2Zhmnu~o!!B8mKf40 zanXI{5KMn+Mu7*>NWoC(nOGs2o`=-o+Zz^r<9-(9I_t(SX2_O=MPO z<(#o6zMaKT{C{|$!|m1cA6$!wV7&duG;S%r9uKiwZf_lWORm;Mf>i25(cFv(@epbu zE1b&Py+{^;8{S$R(Ve6I^bhSU>T5Vy-(YT(B0h`)Pgzw*Pi=(^tM{COnMzXF7Qs(Z zPPSsIX(T-`f34!Hqoui+h?rZQbc(!L-2Fj07zT;fY^3e^-3s(^mM}D5l%F2qr(boT$T)gKq-Zu!#W1Ad+l11LfGDcUu ze0x&;81czArHbj6X9rP6t64FKR3Y6IF*DC%$NOfWi^L)3Jj{vb!RuEdO_oAN3^)bHmU>EP@MaX9**Fcm|yDn+Oyj!zigMtAKUA zT0ALT6l1^shsMup)#_iU2q;Rs*tzrArTy7YiMimxBnY%uOHZ}mGXHKE@E)eITd}lw zQh^Nn)!4in4FSTvds>1|FLHpd7DwgM3ooJQY_{^Hz-2&na-x_$Z_Z@Yyr~1Jwfe{! zdAX5}Jbef`J(IO1F#MmSDV%AJ=ybH;gZS{G!i=#j2FX@9RI9?Nc0*M5e0WiJ!RA$_ zgdPv%9Z>{94lBjSe!)e86B>W*o>`r@j>Q~GKU0>s|I!6da5Q(N6Ap4BLi&Du{eFK6 zcfrH|47=5%mdgVKIgrQru3Hj6NChu1xh2hE zz%h$&GD*O76>>TF(*EaBfNkK+j_pcy&PS6GI{qsV`VZ+gnca^^eIt^r!QnlH5rxex z0l8aVd6n-OMI2^UGNRM}#X3mJee(M;=QBbwXVCbA{KMAtt&=q>xO=GXOdY7%ZrIQ# zV4GlXtbs(t{*dkD;?UDub+dDTHDKZ-GnsQ-6!;${Dv>qQP8GNoM}_i7h?Tm93TLVykJuv%KuwQ`dP}6l$Qk6c#_lX1d^OAKf!G6*vT*(@TX@`H2k6Yzt$iKE*N3+)nje73Y>JTGV zrU}0>lzkCz)UMl-cHo+=;z2 zs45mbq8`C)K$y^k{VMsEkpGHQ`RX;k>B$|0=-{;${n4%6fl1&jR~c2f?boinZy2lMxJt4PsA zg0r-8N2Q=P6f!{R%%ecaDvAziofs~4cR6Y7*;X(4^~(kmPsV2IH6KA3|r9 zcL^$Y8kl{U(n`T%e7`iz{BgDKBr)e^Dx;o!P60%$Wm;jj6#jbilYo^!?7`XFnV$s} z;jkI))KD0v4T1n*JM=hV3-%-t_dbD<;F)ZNc@qHm{@>y5@1JjYf76?{DdLcXsd?R= z`L@c=`2DtLzz7fObB1CXitOh>xdTdb;c~h2gsMVo@LDo*=wU({JTxr>2<#WLN^1K3e zX891(OY-FITth1FJo}leE256Zsy>BK%*cvSPkRlIH#e8upNGJ)Amy55C9n9eJb?RQ zK!!;Su#Itw2v4O_8;&QII&w7uU|nRUhFepnaqSW{uv4UaM<1d)-pn!sqchWd1g|$b z+&!A^no5FMClp4oOQqJ!^h-}pPW>z(>{|X4kUd;Mr}F4^XepsFq_zi2X=D0>GuUS| z{=`?7|C4&j)ouSJ1j&)uzCyQsNQ^$JRuY5?=XV+I^)lb?NM6_qvu4gD-rdN&OMrF} z|4v=48DBW`!Bc1Dzi*hP@Si{X&aYPV{O<7_`!5vvvF4K&FjN^b4{y5}nZMVPuxuS# zUNm?BNuRLQkTajb^-#B5rt@8W^{S&vXo#%EzzarS4fx=RZ##bv>LHU?*Lp>kVFEET zZ|MhrJm)NRa8KKCjySWIvQ?!@?FE~3^S zDY&S*SuoOwVJ@}_LY!Pj%OJEmv+pfvR~MZ>Oq08O8%Or09q}Nk^5q=epJ6yxb%JQy z$8bmODV#^kp~4NT2pffbpeE)b$Y_-5u|bNxZbeJG;ou>Kag!qtuI@&y{>q`ev&TF@ z`T$FIIzr;UD{Nu$!d;sD`mP(<^d+t!QgX;>Cl4)5aVN#o8>h*a#y-9DRYx=aM^mAQ z!^x14mh5eS3`uOQ+IX2$?G81t8Q?@<&bbUOf&vU5suJ%DN=_|5JLhU5rg-|{U7@pa zAJWIwXi~-^bBXrJK)=vkRi4%uR_wimNDEyGS;m-mXcF4@E~!(WiOlc&Mkwfuz;8%P z(pg=?P-f2Gi{~KPNAJ1vKj~j+oKg#+5UfjMG-YR<00>(jEtd2?K?w+jXI>UDF>G}bLz!N+mWx}ME{sELKJ^&4 z>?=gby3pKg5QSRozuhluMO|%2BVuZ#1r?`-#x(nf2yBheP&}lTSVO#S57G1JXnbS` z(sl;1AnJ3}nqk^02+pGLWq#$Z6Cuz;4p>(k+zm=cb5)pMteB|kAm_0y_mg7=3tD{? zLd0w>CCDD|&73`}acSO(oUd>*v3XEh@}q!s(G%3h5@IPPJl7Bfek(qV17!_5W=&%tD2(>2f6^dE zxn(+~aUrDHGarq_Ju4jXf#4mt*q$m#+;?my8v$Y^L`>W<7~S(D!=Dvq-zyRUu2F~n z(IY&e-)Ykxa*Czqng41@#R#W!2%Yfk<`&{%Zw?Yx;8FS;4LvV64=lap2>xYUTWnNj z1o$_4#djp(Wk)BN=Vhzuxd;tD-i~a4C8?{2qzU6oYe=9sbWehR4KA#;5V(fZ2S!fR4~VtDrx{kukKixp_GrzJd3Lb`;zRh50YT* zXEIeG)p-^IySQCR3Gw>%#aNSySCS`xhuK7&7!gM34>0}`UvtMq`3?7;v>o?pRDcPx>F#l6 zMY#7Lw<^`Z&hH$+U%k93R~TMpv(dZ^TAd$As<82DpSw54AIse4n+cM&0CEGK^|nI) zyVs(=T8n~AfAdJ5R_t~v2HpS4bAWW1_>xX8@d0=JYPh3vBAZesfsZng_jSf_B4o#v zL?h^uw1hqilJ-Eb?EV^kiI$|B9CS*3#H4g}`P}d7 z322)YgerpmSfZ|R@JCZZ?%Y*8*-qP0pFLXSoKB97(oKvvmz5MDrz46~EEd#s6@6(} z^QXwE03WMV=0ovmHTrbM>XN?6x<_=e44D=wn2rH<#cfL)$p zoMGaOjRUfqj}audsp#SX8>gp=JG*-Jevfj;ww=P3JG}qyc)1)e z1Eo{%Sh)xL_->2ZRowvBysM>YR&2{h>QP4^=d_!YefUJaPH%86aBQ;e;&+Yuj@@4_ zz)}9{FWF>zh6SkV&0Jq(_SM>zB{p{mTVec1Dhw71%vrNOg8!wL0uInSc)qNV^OAf{ z9Ogh$&G6#fv@Yt=wEc!*@@phpnHwB{!H%PZUMM5m;kWZSIacD0)fq?x3b-cP+w1Dg zKqk@_J}3BcWOwKVEQkQap*Dbp3P>te=Iqr8LiVy-K^I+vE->b`jcMK>y~` zYbEP={1+{5p8yKs+_rS+SD!DdZV*?mA)pV=cY#2+S4~HGnq}Mbf zAhcAhiRRuzC-*PCW2kaj+6BV70-UDIR)!7pD1<%p4z>@-tE~}c+zNmGq&XOMSoq7w z1hD!@a`5zBn;Qa~t7ID5x6{B-eUiX?DK;_wLF5O_SpDodA>KH_(QQH1$@Qf^{yLpx z!>}lLOZ7HWrs%hvaqM2pi_bRtZWWxBlswSnrHYaWd&m<@H3u22wPIL9-FO z&&O{iZaClRBky0MeNeIa2vlPbmJTUIV9>a@Ux?w1%e|P&&ve@u1KsH4QC+7R`C65N zRQvvpXKNm~4z2fMPs0Uj&v0w`TMOGE9_FQ1HZoO@QE%-B zJS&e@x5UjWPS0r6nb&9^gsQ|h$1E_Akg6>fnrkOV7tIijM0lSTZgxmWCx)qBlD_*#Kp*d;^Q7bp z{Oli&W1CKN((9b1zRWpmvJwu8UaVe0_tQ7k0{W^KD(p6Up(kz()D=b)wfJ@sJbrYY zOB#s@#Xy!h_d)Uk@5)taK7pSRyF%gd6cSdheqSYj;}@^m?6-{K$dMm&cqC{0g{;Dr zEyX9Kq7pkQ;y9ZHJKjgJz5Hg&j6r$A+UoH zIqfTSGc3{jO9z4r4>qSV8W?<9B&v#4=_5(N+q`+eSH*4%7v<2tuM3|S2>y*g5-#jm z8xAZr{(491c>G}ithD<`AwuWiPUVhjc*q>Mqp#|M$r0a5GOlKJx&X(9!Dw|}K`c^| z3txQ~4X#Ut5F0{lBSV9xucDq=fYFcD-}KJJ%rA@0$WnhKafO(CW>=M&JU_5qS=%s< z*te+Qangt~2wCu35HzQTE0iJi!Qd<9dPSkLzRH3>q4IBEID=}JGsB;~6hJyd3G8e> z-*JN9Uju3uJ=Je8;|Aw~Qeb{)9gey2s!9Z?=VwuJ%^CJ{-I9rz-I{z;MT`i)b}Lmbgtn5C?+I2fz|N=?MawMn^c2Q z&Tvki22hzKQ#*@4^$3K*8i3`#y6)o8?K@!^!Els#`)mWdxm_ZaEGzVBuOiY+@lY>g z0N)#|e~ehD8#*GST0M?EU;e3Qf}LNxy*BqS4=7V{A1J8VZ*{XpR@J+7t?h!KuTOih z%AhHJ(7Sz2X8`70&ISO|zwRqM<++Ix9HX;Xw3mvpO)LmXJ=e3TYBvTcoWPIobY3OZ z)pJe*-`sJt>F=!@{GTjyHV$`_ouEr~+!dt8I&77jwA)&M6{@RtRj6m_?}h?JgyXpBTr}gQNYm@)v;;G(qg3^`@n!ehaeIiSl zdfI=XOCj~OqKYIu1D2nm8j(#GJ1ge7s4-9LN&Pyh?&)NyU8e^9@_k+tmggsSo9Can zM3OpO--k6FQY!J?Krpj>yJJ-FC2-}#ND$SxYikS>plDAYMqHoY)Pdi})vln<-U(HZ zW1q3UNu{_VnEQ4Cu~ekU1~mwY2eL>>Gqp=b>aXrt+VD9jVj?f&u7bbHK)$f2{FGc* zVld=YY9L}$l2w?x5V?*F(H79OAar^S?giY}{pCfUIYLiIFBck_9dZOCs-97z#a|&_ zZ=HLss1BA<<@@wI&-vS+V4%T`#JWG&{iD7-I1+gRv}A|Qxz@^L+AZ6vsq9K>-0$3) zejusiYP2~Xd6Iq|W!@sUU`ogE<9! zAxVr|ew1FCzOM;?XQOf!R^`clan<3f3I<8UUI=*1PAJ76s~{!$TyJIVzZcgM-qz}r_eZG0i*C~FDMj#hUElq?VF&+w$f*AoPT z2ulD-?QXOQ5D+cMUcqRI0+)QI4M@g$-)Qsfg>t(~Cx9Vivv)npQ|<&{1q#^Kp8(2w zASIMv*a(ovKlNxwbQ|gTwqe`RF-nAnGTbQj%d^G~k0F_(Duc~F=)R}8U)0Wnp3AM- zHkwZRhG8gUf~JE+aLUX8do1B=j(FMeXe5P9eo)o3QG%);yXhFb%I6pPKi9(gy-<%} z`1<*rRV`*%@m2XJ5^;rF6v<8r=55LUTsFhK54c`%^DR5_!nSE*_v_37M!&*)winbf%VOI%mg!{gzH_ z&=D@Er7j@&p8d8%?|v?(>h9OuL;l`qpB zId`s5ylp;uoEw&wuhp&F{`;5JD+RQA@W1~coSk2vnb}||rBQ2V_q!(XFKD}+i&xRN zRxe=3f%OpVR5*q3BzhbAU(RxMXyP3Fiu3y01B*TDq zhmQEQ0ac1J3dt5%gbJ17Pb$yPFI$7JNV+d(@CA|b^W8uj>5RUp=076^w>9+o`MrxJ z@E75+I{O?IA_CRNf3cq&_#m)O@89xKU>^$6n+t5DjbamKtu>mCvc9}IO{H+{MivcV z7JY%!+ECd^8@YTmdh)yy=ynTI>!CZ%=8lg2$!Z)xTB~5nnUixyM1DCSe_*vK*dB3W z!^cx}fnt?onEy@P{zIEHXn3wNBU3L%^|J@xlG(>0&%_x$WI^qbNyr)hyeMc+>E`7H z2^iAHGn_T}yH=VkPU+vv)38n+s(~WyZ+QGGW;N^MmAYaZG~mgXPP-p-yv@qpM?X$a zDedp%OB6jQ8XFc|KaRG{jP5QQ3d5{_M%ASm#!VKXg`8!!0 zioJMWpU&w)Px0w)hx$7Il}=%$BcAg0yN5|YsAU-?0y?&DN2o)U_o}J;63;}Z`}2%! zQolQ*1Pf3zsE5=+Zh58ajUr;xSR*dUNoFo?;)VZ4XlhB6^M@iI?i3lx_dT3jU8~kY zuIsI;qJo&JnJeW^{rQixsjX%~Q$9~;7e`9?i&)ZoIEkxJjYC~{a(>}_ovx@IxkhCO zVzC6KhYS51O}S`xbrQJn&9mO^#vyMM)k5kq4T8?B+*+m2qtjjJJeKSG z76iq%B7^gcC`bM-zYq5@3|@!7=i1|ZyVTpZL-5mtS4T7!oPjan!4L!$x3=9TO`Knv z!q3QeS4#~gaPMELBi-jV28eE<1CQKwNFyM0+fsxqhrnD7p0bxo`u=Zel)g!p>6Asr zZA%{nwN3O*?6^n$P&GL*EE0oHZQN6!hzn@g~;~8Ru?PAuY z4`E~(AO1^?qE+J0E@u2#=jC#_o6O#zhmK9vedVxPp;ST=X=p-~v!b6_Rlv4QlmfLq z=Lde8Pl|@3N=}zTgYqweUE@tvo=~ej(_OQmJlPK5!ut_2LwgZ;zYHM^*Sz1V{%zrS zf#y$wRK@fs>8_+2+(<$?pNy)rwe9Z`SpHs@c|x`C4nc8yXJ26yD~ghjziO{tdyw&5 z*{ooOSpo`v8GU>GU;a*qCnLwd8-?7*zizibbXC+DGM~S4A(nf>=q}9L+g-6Mn#;D0 zG{O2hilO{Qg!$%G7?0k4cg=(w3-g<5n=i5<0uwOc!u|3a?Od#4bI#8Tr(>pv-1Xm zUl}}&Y|+MEj``yk*`_C^liWLP{T?b6C0M?c+z_<7(mK&f<31h3Atv7S_Yk zCQ3deG*gVgAwy&q*oMS-?c>>$JrlyyLNH#S(}am|>|7gF9#vwJk|u1f%=J44$OM7A z2CC{|bq&CV@WQd~#YNQi)r#;pK^`tA_GIOu8@;wVp|6q@xKLwjyZg&Qe6P=uN*m`m zFpVz?xYFmy2(9p3n-qxpkLbUxl6|N={PE zo!2=Lu~@~vBCP|$84^UJ8QC@by6d*70Qln4_zpO0=)VZk`^fT+l$y_U#=e&#lV;fC zQ3q9%l9pU5kH}&bY_Dk@7%W)Un_02SME2>QqpzXit6U+>gd+1$RzemOM;pGbo3KRk z638SslgzUMI{Gma-%mPlrS9P3Hk2ts(5Q&5KvD{#i7Xix9lHpj0O9lDm*3%)NDKg2 z&~2j&Nf}^V)mrd?7yUp17rI;aZqPd}wqnh6k4i)bS8+@kTj z$4*3`7?I5)+T}stnzDG*L`i1xNe)9ex0yZeLRS98*|844PZyL0xoZ^@Y1eMFof!8R zj~SFes>uTo%gu9=*EY8#wC`d3P;pN0t#wY1jAvs+a3%mCk{E27^40w?AMtOsHre~> z%@@}D9>pqt1*3eWZfgNkipqcT{Gq{d5+=W^@*u*0f*rZoW1M1NA2x z5hU_Oq-%ZxU1oqv%@U`xDmwAVInzUi7*03zYPL^dHw{sGkBisN&?ikm7KJp~?p)3f zBR}a1VsjnJ_8SQXfSVSm936l$9{y5ezrmFjpi_PtGHMkRy!ilpef`NP)WwH?R3(J_r4J=Q&J;|7&u^@W1le6GO&N}-G90xnOCJ-040_kZM($tvk)UwaSedOW<|0stLz|%)FIJHBlI;#7lk*novUhI*o%E6+3`I>m zl=%v|M8?Q}yeFRzujTbiF3&X5DTfq}_y>yD02R zYOckzB*I`O!D@--w0zKWQ?I7P_v_-%DYGnprV8;|59uP`vCbB!O=r023rDW&n zm?@|G@*owPRxTYnS1*%KuU>h$|!wEy4c=A-rvi0vK2(pV9bK=>I$ z4BUT?Z`E|$b!m3=`4hgfxr>X0(;3L8y(j4l@0(kKK!j5PnLD$42{?71-qjLLNbzk~ zJ2Tf5)dwf8tUBx0;*~vY$V}52!;6b%tFFeVMN`#dKe`8Q@Ud-0Y0pb*Sd-iRIAYg_ zT!k`oIG$IYHZe-!fruR6X_{Pt1m>HEema`VZS|A<*=^QnucM2wAr&oP)fJk2RrO~soOM+SrbtlMB9I4zu8wuYLnsBR ze8o4cZD+#FK}e_;6yc18@4b@sz>RszKBEYZaK1@JzCWId0uKKKty|0;LGjhQl%~$mu;%dy}}Z-%tpEab)1$Vddh1d42B8T)}hsH*d2wpFpOCdF6_s zKsYU`Y6tPfsxK%`ZeX$)?!*7#8pt5|AhE}ZN5tVIJ8e8MG5>5|yU%m~$3Xas!Vzox${Cg4Q9ktiIoHPeJnx$FROxLG zfn)rmStPL9pscjN7IiZ%NK^lRNIL7VrvCqb(_MB=!MhMakqXz=gj2dGb+xhMD{r$6lc3nH$b?uzqn zOdC-Q{VMUdLCh`sZkX@RsAV3s6*;xOv)w^|2YjBPvipKUOhPI*W@*EN99Ivh`Dm)B zFB}5+cL)x1_z_O?D{hasiC9zgFhx=3{(ZwKJ_=^oTAXI1*r`DjWQR!4qobvw?_d@L znF>Bvx9jT_@MLii_Z*DZMD%1`q6Vgg6n^U&<3D(Fl&u_mR*#_YRDzOU&hI0b(Ngn@z{QWXIgOnpyOS zwt1hcin6r6d0RKhU9}1*xV&e4ybD7%U#{+mtmxdz13H>2!imd9WWGA}wbzVh`n0_#0{1~se zrOgFvseGV+#~W|)oIdKVx$VStN```wgrVyXqQ+mD$R~6PbJ_0eUr#>AU`XjLAG*SW>y5n<6K# zi~MsDr1P)PeD5rW#xdaInH#<3aueHbb0)yWS7!V^-9dq4Dg;B@8hBGI(-wbDFz=8& zlzLu7UFr;Uciw=rUwq%q0ZdL-UxZ{Yb&5YmNb$!tJ~~P9+ESX%5-oJFsC_jZ%6)mX z#??PAg$@#hxQ?f^jo;{vc_ZW=$_G>iKWyZ0io}TLQu59Rt502K%nr8dT}z)V*W1vg zka)bsz?XvUdX8Vr+kls#>X_W)Ccx?U7I6H#F3p|F!rW+tYT;;dqeJD_&sxhLM^#6) zUCp9iU=qK+cEip&Ayf9!_G7gAl+9{%1EM@wX)7%8nT8*;+?U`Q(7Tv-u3XxXi@Nf! zKeA?4kGO=T-}%4!_-;og*lWY8!fC@P3fToOXr-TRs>Xbgn#EKz8jd7RZk$P{)>eJ* zr%Nt0-AcVbj_|snsyQN|>)JzTE?D{Ef<%bCXYxGB-??$1kuXqge)iwUCSd4RNWNU- z2qs1i9e+WtXs+hQK|sJqB>2%mpjwlWvms@MaH8lmf!J)GLdED(seg9y6uexproG*& zy>VD4c6zc9?R3h#u3ONw&b33Ru=$0Wq)pB5`XY+Dbr7`4wKgcs&Usj@S=Cq)Vu!IdUflIA37qMqHC&? z`GwTlckMQDRxP32*`w7Y=B8d;v|v)?E2)?NoJ`MDIO^|OT^>xHyoNerN~S;wtgn(m zP3+9`0To!1P?n{kFeASy(K(j%h$^}~m6J)`YSCQDYjt;tsF9rM<;Y5FNAS6j5Yrw} zF8cu@TJoWUL)q~AXL;?M&mD^2V?SBy38i&C#!5am(nWXoN7Yu;R1T2n%P(( zlhIut+~vDfYv>F3L%kN=u>FkRfG13O1&!9w)idy^2kcnb%i={8_PP{aoAaH*pm=OA|+g zFW-p}W5ucCuz2z1r9gQPnCjRo0bh!5YL4`Qr2G35pUj=0GL2y2E+G+lZ;f^L#%!Ny zU&7p*KM(P1sL4lGBR3QX~y)Fsgw2p0Kp8S~pK$lybA4)Dl z?>n-^{+tDFJj+Fz-Fy%$!9_{gM?sGETMf!t{DT%qdsQzWMAc8m>959u5TnHlLm#uc zer?&$;M<24l07te$aFBtJI1A)PrpJ%NJ+f))Eo^LXOAq!9eL_ZTBFwH+#pNL`?8@& zmBEk*95caFZ>rmpIcra7nZHMf+g(+|lg3_8S7XwN)ql0!1Ak|0{_lEJgUW+Ba0;TZ z(@8@skeqPHC#SNXk;>w9BaQP5$x_(!oB2rZ?{d*NN0ksR6&Bbx@Mf7etjYgDTM@1#eV581Gv34YMr+&yZP*0*%I zsGhD5A`6`fI8-zLHxMCg8YxvZW zVCZkT({fjqD7bG0*}Wa2OJ&0OAsNO zHqf}Nz1yV9RphMoUBpZ#yv{5fG8%9CP zjJijUkeJB&00K+v-KMwZX2V5yF<@ExkCtAmYzFkldX6Ey@8!?$V9%x^ zUT;91XlWSitkKiUuJ=Lkal`b4S}-=^@D=uoFNjmKlktH{Sktc55h>0~>6w$Dk=nm& z$T5@Z78zbVo2-P{j>=8_&5hp5{w97))$d#Pg$w=KNJYoHkgcp2>S^=mQS@ns1#e0| zok&hwxgV_27dC&q2Q8>^ur|OQc-@pC?&Rn_=1!3;UBBkak^wDhiR(aMD?l^Q+;`N) zoV9Gp8S=STwy25mR`Q?6>BC`l?SizC%{zhO$TtSpZCYTr=JO>(;p4HB6UA78S2rCm zgcG~J`r9t`YX)IR3#fQ9{wog&@ z>_RW95NGA{N01Hc_0hMT*AGhVNFPFjA5;9@Hs72JtJq}p@@65+daj@GW{gC#!#8(+ z8dq_4HGVVlvh%ylf1B@@0H)tiAcFD!<)4|zGapN!m{m>D!nuS`v{B2w|NZSC`Efr1 z?~rKnbbQVFllN|7_Kc%sUz?@+*>XPnT;i0Ob|3=RJi@m$rg`rVYxaJ*d%7gU?`ZY; zq7@7spmId1OV-?mJHFFeM5Vy!xSm(^5dF}I12c;g#kco&CKm` z9~DYEc{Y*dkj!<0&WHcD-iWnhIMI$-rk3qqdq|NuF{65)B!cpI#%x-eCyDH$#H<*- zme{O3H;5eHB!wK9Kgo%9gopflHu3i*S&O={&k`eGN4C?J=%YUe z%cdh;C$p<3;AROz$x#(*FNN;+>Neal`jBvONi2~etUyvlF=fppn|9mggi}}JaMI=* zEYKT`7We}LL%nQ`5on#l#5GzU4^P%zsGrI2tIlK$UI9EqUpcOnqYNku2Z zbHO?sC&i-WrkOfQCv`=to+(=wfbF;$;AU2W3MYBgDq{psb9Lu)qC|ukAj-Ro#aGJ{ zl8*=!{Ni!bJWm${CN|f5ev96CoRzk)@*u*gs^h~~+I{p@SxI!c(|e|6jZOl6*IXo& zf)(?GhP`TnROyW-dhW`h6wwp)ljH`Cr5v>%)#9(#V-n*TTDmtpbH@yR!#(cWe|tFF z=1}qd3H+Z@ZRVVE$jwh=@(!-TqybgRhgbJ-wI}Ne>mjc-d593Yz_0mE0;$+bUY|~y zCk=>oy)d;i^YD}RH0}@vV7GP8xmzG6h6pHC4m`NhzALDZ%_nAll| zZuPQEDkR1d&4J|`=G?NXHsP2Ec?1@f&}A9fB6THDIj20&ptAUbvq$&loK2vu!bwc) z$^YoB1%+|U)m=<5B9;k!@;v=SCqcvG`qp8W+5 z?09fHx6XC2goWIFE&l?D*@s9cL`t+@uy=K^=!gcl1v0t7@Zk05x)35YNJ-S;Jv_rV z(irq<_8@8VZayj;ayTVmwPpbQ3EGRT#8R2t8K(?<>xp`ZL{8BE zRkq#60W{$hGK2g>>nfB#DCmp%y21fF3fkEhq4~hT8vB(!?PSh$&(nzYxC5sw;!!LA zkEA+e-`SAP*&<#uYBlSUoo-cD)>Jmm86b*myL(fJh+O!2U+WbzX9JdCP|t>)?TU#r z^^B3IapB#~!9RkMGk5o9EG|Gm9B0=`C%SciG*e&{bI<0Bf!*vAnHn%Tvp?7R1PC-7 z$#W%q!BA959z}P{0+BUs(syDtv?okRUasJDRx%b1E}$nOULd7DFs0tfKRWV`!k*Q& z6GJWFR&1<^PA_0^68nO}e94D85sIZCwTvgh;k4`Q9z9SNnAtr7`AMI;Hi5MkkQdZRU3jV$~Vilv54p0n~0hl*1;!Db)9s`|8Jc%BkK7G z*V>3eQv}h#*IQ~Hod{FAv9oh^qm#NF#aX_a|9^kabvOtLi{@^n2P!F7Mt1WZ7?dv= z{!$)_w+gMGPF_cH)ZodTu5S`gg`DnO$7{}S)g0P!7WHba*Eke>+a6Q)&!m(rdEaM6lq6W%$BmxVPWFIRo}}%p(VNfG^}t zzhyrUWWg?~Fu}=&b7ei2T5{$FVy@IB{e0 z^b*DeeRdO{9y&y<2Ytqyzl)Ci@~hJdGT(k?eLaDgFKvhnGslh6=%ERqN~)RgSacNy zaSpG)12gtf)CTjURl*_-Xo@@H()`kdY_z@8gjJXzF&U(Q?w_kCS2{ODchS7yT< z>_tU?6l3mxX?h;EK7vz}5wN)(t{%pOj6&Y?EPn8$RoRUyYBlIF*J?jijlEx^6p3ss z*~id#lr{7&r?2yvP3`Y{>rU{SV3X1ezdsQ<)a4FaEVT`G{+i>Y~ ztNBRLfwKnwby@SD;k`m>JtejQy@Mmuf{6&S55jjz$JbXfU$7$fZ3nKLNjHq^cN(4o zg@Ml+4%SZM5F>;(k0Dny>{9QFJuW&)_`VlQqgq-Gf-GGwh&~M2u*1(NW93{KReUk& zKl_1S&*SDRUs84-)o|*m=7Mq)8(XFZKbBMM4u=!ggsz*U<^5e7(5zYZ%f!r2tlu); z8H(#FPAjc^T;NQh`d)YZFH>kz*MMe(Je+r0W$M47a>Rf!OOS{|7==Xq%9pdq+Zw;- zE?-|*qi6n>AQYf=a)&oWNvTPAc0=q8LwCV@jEhxjB%oVa&zgjN2hOIM zA>PLlVQD@Hxby|mWOUHJ>F;IvI78loy60(fP!&VZ*kj=pr3f4idL=EXvOrynZkOPO6U3Hw4l;5*n)LIucIGGl2-N zd#P7(*kgu?&LKYZ!A?IRT2T+xV3zbQy>Tz^0QD>_nXKEI*otoO?w?ZZh0&nt_v%52 z#bsm_e&IcmSYgufx2q+iB$I~UEh=Zm;V3wke#wd8!1C>mk-Ms^D`eGklsbPCmHV0c z!H`MvAcAkAa2-1Z*4c?0kJNK`XjVJ#!;q$_-TUGz?qbsY2IfLR9WE=R!9qX2ZQ7?NV4KMBByC5e8MZ~M(+P3OCrf|ga`p1CN;z9-U_63zdla!EK zEG?nblT|Jdl>hkNpACUjl2p87RWhoo2mMK@iOb2ggN884Q~S*Wp{`DkFm9&W;R@tc zHP<_14-sIL4q=%m){}rO@w1%9wJFS9Y)cARm@R${>~YZ8^i-agEd+&*5$f({Q$I+S z%YZ@t07tneN?CO6mH!-)B5kL`$Afm= zAkTcLumLA`x1Fo)JfUSSfyyk?M5@kbpsC0fQfFC{9sn6t$hb7WBW1BFJW z=Ef_I^!BfX&S5`dxza14XLbzM6#pdHdP>|X2Pee$Nu}lLJqfk+I~rC%hN`bCQliX# zM|!8^Vp&%uoG$gfG-1Ch)|sH6KwyAggWg`mUdwEU;qRCNMg`;YdagMvONAn(CUZ)m zZr<|TWR*W#?|*f-@-r{}O)=m-3-!?swHj|}55=?sZxzs(O7!Jh?Xw#EUa-ZEf zk0Fh2Z=Y1HT%spV6+b`iz!`cTdV-!ZA=+(Ozrr#!hbZg?~Q72cxTlU447~9V|Agby!`nowQ3@1ib{^edsn3W3K1qp0@*3 zFDKN=P`W){kaBg3pT|S6X}L{Qe6JGXrgP%d1!Q%2WECwW)?FUtT>k~3Di;XGNK zh_#CU*8&;t@nwt3DEy%Jmt@Sy3JtI?mTDDhei6bpV%it1uO5xdr)44`D6gRQK_uRZ zZ|cj$(N}etgK+;nGzwtTII&Jy;St4ep*p4nE#zN(7btlB1M)#|IUC};!dB~;Yn@cZ zr+H{ff0|u#!+WH!;+DYyTpInen)4_JjqkR+Be)4c9_?12!vP$%%Z~~m=Co;X8*3RMpAfj~4Azr4bhCtEWO~0P4?N30BiI8d>k#j+iM@AC1BG8ZG+i8}&5RNrrdo-9 z$>|#sh!<2-lV@O7Wo(bUhTTvkPr{S>?)ISy*V^)wrMlzj?=nnS@EME^+&bToc7nXQ z_p_PzGHSK$K}Gu&)q|YI=XQ@3;ND1B&1E(@=PUH2XQ2ng7LbZD>4%?Y5R>^fW!J|< zN2n)FVaKgog8Xi=qzUG|+;_$5#P7sw1x0@Gz5TtmZa0u=%Ue)>V(H*!lh2y~YOlD# zNJ~XMX0N~NMz#CbH`E;8 z0PQvk=dn)(zFH6%Tx^HuiOOmf6ULWt@>KJ z$v;-A<>j8{*d~X`)BGMUE!IO-3i1m)t?^+Uy(01;cQ#)o>4j@TJ|Equr%Yw z#Fs1Y&~XIeD6PDT4eZy?+RphyAzJhbZmp60mhx(`b3^`}%!id<|1(qYtx@~;siPPm z5Oi`abXEjcvcNBjVh@0%yhmUPZYK+XrOzKir5648(apEvB4)sTb0cXffLAoYk=&F4>V8yfKrGerM zda3_*xbO=4ih+V;SDS71(a)NqA*4F&B)TeePG$ zL`s34hG15sKm9OiJYDNkc`}0lU(UQaE6V*}&XD5ElbbTK{4zYU*#+ zmbXr*NW03bA8ZYj{@Z@|HAb0G)XJ%ZAG_gULBoHIESM8OI$resJVd*Z9dTJ=2r(~C z{`OXyUc3i6IXP#=o83@D<9}l=&aPzTP z*$JOVDNG~GkC8GM1v2o3=lb79=T66HhrTYvANRi}D_Z{z-+zq<^7Zdku(l@Mo}^nk zSY0$LlL+(#3vPBimY+6A-h!ex)h4t-=B!9@YOksOTF}*4M0fQON(?u4GIIVOEtfY7 z$Re#aIs~pXALFfJATMvP+A^AMsI>~V@27j|xjO)ok|9ZU;?1=pCkYOGOH@%td7-`5CU9`b+h4C}kyYtUo(A?+Q_D)pgH_;*N55}M-%X`3M|BY93h&J8d zRmr}cI3Jz6l)Ba3Kx*qW>?7uMFwyKH3>b#wwL;Zq2#$C9CaT|e>FuI53FptaeGE)k zC>SF==VHg;*t6!RAEVpLFlU9lXeFGj>9;RDBYoEQXfUNS!zmHSiGJ597&|%5E>?%H>!13R@I0P$}^}!=kWD0wn zecn~Lj{?0l3@l z9bps){yJN_eBr@AX798wg+RJ9g2H|9|IClV-umL{M!)&IB-n87Q!=AHe&^Z~5*V`rcE+5F}8Nu*yH5C6`jTzy@FuvF`wE>zJ|>5jttAIBUmx+;MXet zEn%9~qJdgMZ6?-G%fx@P7`sy>I{lKD>v>gkci>c=6PEBn$T|-wDD3FQou&oZyu91c z4sw#b?zg1-SFgf)?K27L>Z*yL{1R^sf$4TQG;Eayz|-Xi6Dmkw1%GTlNFZxj8Vvy? zZb)akK+s!qPb4XPZ5F>gVe9r$_(56+GR#fIs@~aCzrsc?ccK(b&XuH_n2!{NSihId za!k=oPwg3LIFQVkht&3Wkon$YIvXuAq45rke^wFiDm|J{bwmNKK~`)^-CG}}I(i6M zVQT!9YlFWE%Xuridq=tQ^J05Dk9_uTUg(gzwXE);gpkheluu-O#a*KP9Bcq3P9VsroubVde@+5rohgB0{nbH*lt*<=n5BRbExCi+t zzKp!YtF!cb)<(;ZxqJb3~t7&e3S^!wueiW)Akm5 zwY8DLL_Mx=1ogD@A-?~Zn4#1u5;y+MC)yT6mhd|+mw$>}>Xv2NEDfS_^l4&}Z0N?p z!Z#A=nJ>YkiDQ3@VaUEUEUsXe(;ipdWq2(-h<(OS{YLQ0vye~Y&l$x?S9`82f`#It zjnVO^ep?+91fjhvi9=C*t^hW@LQYo6iZN1U*#hejVVL8E(oo!f0d6V zZ%Z~b#5{FT{RhQ z%O_c_t+YP16>Z5Q_(4^D%{x)Ke_wuq&ce#O9lqYuI1W>W$4Avpn`CMnnY;<4zxRaS zK|&;H(?veKdme&MdO(8xR23JTfdDSfmlFFNJ+<%w#k!s>nVN7JTdV%1y#B=h4(GbygaPFwNX$RyedxeP%PMv0*)^~qCNW-#FVy` ziP)h$t5PUkC9RQu>-O7mmJa7HT;r3465{$}PM6F9){01QyY>&n+kAuj@u7i(h-khG zpq`_D@?Amxy{7xIy3Zqi$-DFve^}A*l1VWrdVx~rKt;?L)I4VXnX+I4s6K~EHV3`p z?B`?Ft%};9XL%>7X=-M_vwy1+H*yl7cD`Q=e9y@>SlzC9I#7|1f4!Iw5R1tHTwa`E z!M@1AZ9As*$0$krE2G=%l9I0zQ7Y!Z4aWZKNHP)z&iLO_0{;Jl05@5loh%pDagx%h47JQ zaCP_PP{0u26FqI?i?O_OK1+3@B1|;KvGJG#xRlvb)3;qum!6~B#P8o>H?_sr6;?nE zA*UQ>rRscNV6!G4tId(EtGJzMjDO>Sn``TZb}kz+Ms&cyKtj*ZO~Ox2_EML=ee!oL ze5xKPhJbw- ztW09nIxeXcRgb-Goa`+4_Y%^|RCOCe4HzA%ioGJRap;iGu(^i9YH<;^x=R;WU?1Xb z1s{eU;k>+2@?SfGR{06JE#HfTbMBleQpEiiyT~6d^rFYx;F8RZ0#urn3yi3296376 ze3cRE_QyYIGR*+ZRxnx==p;J($8_IdocUruOBKGDORy?BAa{9G3bzw29Wj78I~tmF zLV&fYpxJPCSWcP@G14IH(7XnE{tCJ`Pa04v@D$=9ksZ-8oMgeg76K71e<8qWXBd6g zROa17`s7?@``gmX+zt`!>t|JDFF8z*4#LfM{w9!oUf;Z7F2Si3{%(D{KUOSDi|yrb zjjSPcqx>d1A+f-HEaWq_!M)TZ-1YjgT3D@Av=*s2S&{C+BG5h=eyP-QXWfGl3!E=q z=OVqE@f9*+S5r6ec@UH%zcC_EVDqgT=oZOWxN%rxG@?8T${L=rXT(Ep5X-#$54BL_ z4Q&Zo|N5OE>pz%(0eNyW0=SDHQ)~xGK(^R3QGS)n>R&+zXkAck_)E~UWUXIpH_wa~ zs?nis%z*&r-7{&wV$W)S3Tdig9A4ozc|F`)gfTKzgu6-Z8XVs{V z_f9d0H_smlxz^d^K>{%Rqo=`$Nf-;CiQk7ADu2$`-%Fl=nd#O8|#mxN;;b*jS2`akdToSapo!C*MRApb()- zFk&}BzVd*$GiYyENIm}G`GXnLW%6%ET7IXM2PN0TSBtpxz5S86wn)6A>T=;AoR%r{ zqbC}l+=%xW7gVG(?WS=lZaq#9_ArZC|Roh9ur|2}=G$@da$ z-mXj><9(+PzNk-Qn*1TZISD+XzowgXCFJ>O*r69UkKz9&sVj|dnSR0p!&OR9Ty4}n ztJ*GLF)-nCA=pGZ9Z}V9INQ;kW9$}N0bc|JI()CL70wI21X3@y&nYf zUe>S?9Gw7#%j{p@oPchxsI1i|l$C9`jt4WH<|88+!7SqbUqgJKG90tDM-`O;Qplhi*Z$G%9`?dA77{DTx2(>kGi9 zUV3L?k?dG~+0>mt;)Fd-DB&C)u`bM``KyQwdzQNT0-T(d56eQ zb1xA!cUqk8!9Gn z>->K3eew{r&{rkkY?*?Rxv<)N&qp0~Y_rye1yDWNzc@odhWx#5_xKeI)9vnAk zTl#O9^6Ydyv&6#&Tcs*mhuFnaW!Jpf!+PlCWU@RJT7C+|qPtbEddVGU-UR6%2L;L!{g+S}% z9}?Av^eaaIh8|OzY!M-|T{5pXa&t?@CSy@QN)q>t*0dd z$Du}j4xA`NhTwen*xFCr9^$70cC4?@N~iHyB1mvpLibT#UrE4+TPA5Hu(OdFWn)He z?!`Row(hQo478o@8(q{(OIK$435Fke;k+$8tVm3XJNf@Cbl>65M*(lsQ}#j5Jm%m2 zgVeAM&jTOsQa9DYFhC2_Xs&MLeAcoNi-<{Y(MzIlq7gU-6Bg&|;pD&+5e&Pl(^iPC zUIawDxW(UVeoMb`GDhi3lDdr4U4duZ8U3fJ__FnB>Nf0sIoIN4Rrs}pt--SoO z{sKI>f>M;at8Kz}SFW#Yz+zV078eHWc;!F+LqrFW!%zUF97U_!-UhUb&*^tjzu(W7 zN0dK=$LRC}2dTu2NZD%oa|@CmxgPf8z{TKV*a3h*oSQ)=G5fW|q7Z0f_I|)8p{x-=( z?LS`uz}6%f^TLKdByr9V~~(5S&roA2`iS+s*~&LBf8+>l=#FF4zs%3RhFNK z@W(eBr)FIV^k1)PMJd(dK55a@ z+M>ke&AOVSttLjdnsGei-TXW@0pj#u$RD4!vZM^~DVZH#nr;id9P}GaVu8-8B6;=8 z0McaBzM}S0-yDt47vYGx z;SQ4W(Q3%PVEYs_AIp`)F_Vy=nKcD-cy9m@hcRq-we`e_P0fV{T8CdUj<;GHrg!D}nci>#j z$`z#UroZ%Q5||DqY38=;8jVQfMo_KD2~tgFjuCoB3~MvG)=ibwJrLn|^OW*tUu>fs zUB0oNXMq|Oj<}EeECNnyEY~$Lc2oOj=<6@5#2!o}E&mTUT@_YtS7(1lF%f*{KK(${ zBqi@AQ?u)+e8b7t(#hHf39AUh=_Y=w2No`V(8S|lfDOpp{%{Xfpx^(v6prc6B72o-FTwD$YO zASK_N30nCH4>SZ4O!4N9zUL3@hkoPvV>gGvg(8emI zxVR7$4FCa;5|ifoSXR1(>B|$%fzW>VL*v9W0G-#}Y}x1T_L##LhJea zrH~4k&LKz-J_g%jnM)#+&EHugvxIz)vH9Ye(S$)#F3aU5hjgLK(N58(@Up$#G}O7% z+>omuQ-(dv2x5z7epgYjJe6rGDT$}e&D#HHbD_&%-s$4LgEqyJx2cfis!Vstf~9Od zIe1^=w&r4ZL$W^g=zq+blWDjKqU$YxJyW#zbf7vw%Q%i`=?TbywS`QU4FmF<31hGtT;j3b#lIYi08r}D6U*H{{ z_E)y1El%0X);bNQ}J2u4A@OWr9CLyKl zWkOisx2+1{cFNb{`17mBJB)rY%G|}?b~>lfRiSiM?G?`3?1j^AJ}+twm?;E_(-@2k zC-HZ20vqCRv8lUR!fdq1QC@rXp=0G@4RHC7j8vYDBSI*cNE6Aykf)@s7r?@Gy0jQ- zpy-|c$B!ysh=l)>=ph4IA-BX|Efo2hjs}q%K2+;HhV7Y;3)Lm_GexiIFIS3o{Jt=| zt~_VGDpnoT@kFMIJy!h`Y`g5QDaWa4>p&*GUHhS&smtPP!CM*Y-(_f0_Gj(R9rV*N zKS6}BM9Ndw{NtuXVujL52&WBr9)ZNpEaE=K-i7zNI-~}#Cn;^G%tJv(GN<3T1(>)7 z!BT!t56Pp3ADfa3-=QbjegcnHLNcRq8s&MZc}&!vweSp=s_#ykl3ANyFK&^th<)W6 z=kzP&=b0k>DEbxWaMQjLcE!Yg`Lq#uEc8+KJ{tM%AvY+m*$DlTOiPVe^1v}>e+N%j zi>)OO4Fe@jxE?yIHUPI}ZkH~moRmZr=9SW7T6Xg+ATF}Zy4bG_h zA;?PL#LT;s6FS)&F=l8Lq^VT}DJDAPq@O=# zIO1pNU9%1r189;ID0O2rKpxHE^P)hGRyfrN;VEC|1^( zG)u83`Gl6A19j{VJ4zmXAsB_d7Jq@RN#Gx0N`}^{vMKqZ(I;2~AXSVHh(r?XIJ6C3 zuT5N9GO}>aGxoXPm6E%sgGb03>AkE)k2fbh+gTO9QY?{2FXOvM^2Z6=XZbuSr_4Mw z%PIk(^5At&-?r|XR*T{t0w13QT~q6QY=f}Zd{P)G#&5U@8_?u=_}4Gw*aCQQp*)M8 zmnj^iY?E-P$qgyI_{3-I9thZA029j>SeJtCYc#Wn-9bl|91|9qOiFct6v*a4L_vBd ztnC1SB!oTz*yC8!bw<-%YbB`zluyk@xf8t_7Ro?>1FnvGhKsc?YsqFDUtWn-h14_}(FsfJz_`uK%D zcV?LjT?+dzC0l6aQhsM?6$DqfS)2R9F|L~Pu{!pUA#M*~Q671&&RaNw*97-LLv`^FhR6koiJ^WHO z)u9$k@_AKOy&?m=t2wyjb_AdNV!gqQSjm1OqmO8z0dlUEw6-OtC30e_`fFX%q z@8-MyR5SzMW<{olSiY^uy|LMWh4^myKNktX=9lmpsd9xK4!+PYXgcaD*4wiKlS*;O zS>KFHAGww#=|Y~V;5`T}j=PBpq|X#;VgE>GC+7lv0PzgW&t5C^*5yy#x*R&F=3J5j zVr+30_3oK@(fNS&J?Uk>Se@n?s+kEEuw|C?;`E2$;(L(&d?tQd0AOM z7vJ~OMt?sT38^J+F->Phmem!TzP6nOmcM;N?HeIz)8& zyy=y#TfJ?RhGZ^No5TqhFp1vW z<4Ifjj(qG#ReoSm$5ZmgR=ydam(r*nJ&}5Ef(6)*DG9@mx~BQ9&x0a|(%&l00Up@e z_GjwOzD7WI>^~$WXgJ5yDtznuEdnv$hf!a;RiO2EAhWIODJ-&Lecs)d#h=Q^V%$%J zjCCazWJfbv4$2zn-PfT;=989ebr(k7k9gUs*wlhAVOzK}FpzMMO!c)SCe7oq@lwvl z0{YJj@@G#mqSxGxCx4wOE~ccjhOknitczdF`;=O%Z~h=|#&DBn%)>+;pd% z3dnIjl6HRuEXY4`)CT!RJCqk9F~%+kRsauvVpcvOI(<()yCb}|vk}ttuN0%rLTfrx zOCT}hSo(iton=&%U$n+&=#-Y04r!4RgrP&E1O#bRl#uQmN|caRLJ8^at|3G~x{(@? z&Y^~ycrX8ZKi>Dlylc(3nRV7VXFq#C&+h{AsRh&7r&EKA)qs&@$0?TH@;@7D671LY zuJl>ME+@~62qH^GNhJqekKt?7f*Nx68FCaSOpk}Ldd?iGG%Q3a>?QW$QNA5i5(v1KFe(6zJ~Kj)R_Or_b1Xc9AeX ze!Fp7s4I|9HK3$XkL_B{h&ifsX-l6qFj3!{M2K;af*Tj-MDYAyI?T+M<}Z!09mo>} zb9jENm@aG1I<#hcPZC|ne}=&p=4L|rTVUS1^7t$$u8H*!NI3^{yp^&E02(FWqN?(6-8UN4*~ej#?N6HjJCNn)|U ze(~}oFDp@JmaLKN{M%3FlQ%?NmK7()nx7o#4?{C`pFO+`{~+KNl6E&Q^t#;@rE=5C zcl++l5ej|n9?^Ecb88bK&MD@rcuLmZd$Y6xosIJNBIOX3MDiAWnbaEmN58|Y>$<>c z4+m^Kn3KbUstoJ>1;1V*Tdih09~@a@;q~(vbj~GRU{OpinJ@~hgpaTJ2PblrLdBo` z``X6h@`>gxa9p$*ok@rJvpe!CyqNzGO&YLx%O1 znJQN!_`>AT>1`Bw@4o#-9D$qbj1MZfZT%^xJrEfRHnPTXnaRV29$-1c zha}ex%vz7urL)xFX5gAFfN!trv%#)FfFOaj%NaEJrp5T$xeIpVbP2d5JALxQJrdj8 zVzPC)OTamu+M{(;BA@<7O`LOlaKKrG3y3}Z^S!VJ;uPX)QQ7(Mrpx=-W-~mOqb0ll ziA`kYrPa4Y{4K7xo=Yd$Z{Z&S093xBWknYDs!C3FQT0y1;DtgV4KmOV3UW8Fibo;u zkw7sU@Ju9ir8%~zLqX3cN`TwKK2}p=KjH;xa9I<5Zt~M;Kej*H=%8iw5Da7YDyu1A zf6G8%vBu;`gp~vpzPlHzl7cFk|8r8>mJY&n47NH0K?o8Q^-$~->h})(A;Wq zmBm6!xvQue2tBU#J~=aDzC)fgQ4oI~K)}#OUn?zuS!{0K98}qII=BmagYok5NCWbn ze{#r@TM<~zS6@=^j)y4ZRD?FM1g`wzro&H>fG)jvx)7Uoi$a)8VG~4_=!jO1n-AP7 zG-rtbcqKN&9maf-u%mt3J98}w@p=5B(0P4^{9N=Ri=*O%*orh$ETVT_68M8Hf1wxMgkW{(6*M{lL2=p29WjjAUAwFrl zet0>*xq5VLy%J8uv3WUfcRv5@ecJZJNBZR6FW~(_q9?N#oWvUn(Eq%j@Yoxy;5zhE zi0ICB^+zo0H7w`9zsmO?7W$tdQjS$Avm0G-d-AFGpA4w%Ezir~J5F?qmT$@OGFuw5 zo0g@`&m;jF|Bj*?r^!&~%Tw)cItNcB-2CR5VP(ADdv}F2^6#H?B@P6|vg1qn; zcG-3LK9Pd-M!s;&b`nMMID(nixaf?9`=Z~YA9qu}+SK}P7!VHtJx+ zv78-e4eY_;sxRkql@mhgl}?d$6(6P^kHRvxFbHu_uH*OHJfrC*EBwJLa6I#fgp)Z@ z;Yg8F?rpmve?uI-qx>?vg*v{9+b^NoT63SSHiPsZVa-k>5xjbhr*)ZP!PNw$CnLO+ zG0Jn**I-5K$KH$$(FAO04)PR&WF=GPEmOSWw{g=mtxd$fMdI62KowEAm16))Uz1rV^n z)*`Mc-FhJAfcGXfVrrd|UD%HuC%)x=Fl5ID2EpG`hVaeWD1{TzXY{B%0F$r4j{T=W zpWqP0{>i7iLS%kI4ChPQNkX(TVaCH2SqVD($x8v`=0$nWXIh1DmJ@6U@zBwN zGGQXOx^0rmGhdpCv3{}_5H#hwYaC?+DwjVDmn?#TAR5g7DhzG2qRqsS`9iy;7FN~y`2*3;X3*IR{ktyF;Xd(0muJ5aa-2*K% zj9N7w^bMb_5x-EM0bo1-nHM7z<;OZg~14o*OSa>_XDPah3Z30syh) zG5=OO$KLc1C2HCk8;q|~MmkDm6DjsGhi1_{5;OtWr9otfCQL?+D~LpkTKU75B;2*g zJ1?u~FE`thMZixb_bm=QYK=W3<8Cot<9q_KGZiG8$I0BWHAfHsrb9{ZH&$vuguiG5 zcVi+_DSN)Rz3jF#_lBzQSRR3K-YHjFAC#pkx`1dasYIO_)9ltjH|NY1-$iWxTUt5c zU1n`w6J+J?%Hik)T7mowee%1fmHs zPFEteaZayTAoh#s(^+FbU7L;HOIf7Sc4Llg$SLD>4%M-u{&2rc6cv|CeRDek}sm2^fnO$MPV$+A>Do1#K@C3IZ<6?CWDGveF;LejKoj^oJl9q9x@{K^KhlFWr`_? z6D1;Fn@|!3Khg1CfZ%I5yJFt08s7z(+jwrcnjbqEo_(vIa;19TF~g^O^ZH(#3O@H# zuO&UMS$L@}*Fs4+{`)1(cV<;AIi9y=84DE-5%b+ntga3BJBcFpHa0f*Q_&7{z31n! zg=HT9+5^9%d7gagnf2YKa6^<6VY?f^;5)xO>Q_SG>dit2CMWt@QtsC}TS_tZf|44w z_KKV&RN&KRU9)+#J3N$>tY0~#KnRXKM^i>ClVst5?#uhH^+9M(IQ;DIcQg5HnE}<^ z`1!nEaWG&tB<*$*r3JQOu*t1Ltv+o^ysz|GJ-}vQA)*ndmv<@xWC~V-FF7(i3T4UpJLH z0GB9$t<>=D8=YAw!KZMU*O)4{aQ5ejEQn!IjgSzvu}AUAYd@Z$BPliar;2MnXV$#6kai#cxoQj z?p2+C96D>BgZ5)PEwL3LM(+4;CPbdsn2gT6<<8aaA_*n-7R3#R+^Cz?oL?8<%5zPb zC1+b@%eu`EyqoDEKTc`B3=N@MB4nEaw|1d!r6aDHjv4ya-MbaDt8?&fb1^*MUf9B$>X{9s!L2p^cx^T>B0;RGSu>hGTm3W|p=R3X;PsreFK z3q!l(5-e6kavPd}AaNgJtTkDfmFytgtJlGxE6(S4rgXHCnU~ehz zUGRpU{}tY>6}{T!@aL$tC+V}%x`xmBjH8=hgmrzFoIfACZ5JBsxgz-hG98xNZtF>R z8?KfcXSq*FJgClo-(vR*Ty)Z$mBIm5y{>H-t_fLMrvl+ysvXtOMIl%>-V-F8RpMGGy>dZawxX&!jF3V zal*Um7t?PehgX2-c&7!kwWw;0uXn=iS2YptyRhAyKfWVPT*;Z4#=lo?`^OdLa~8Bd z?}H_gT&ox6NIc7jL7k;Wt-Py<1s?mq&8li8dmoB+aw@!%i|A;cr!@1$F2Qyb;8)9# z!lh+r*iHRU*|Jh-#TRD&#u>r}$Ew1Jh)oQK;&YP7jMKdDA*E+x{Q@NzLJxuIm}`r6 z7FQ12lWFB4^5{PQ=FApwfK#lpNI-4sR}2KEs+m!J0ayH~0;BBKuj zbt3ikjirBzmdJyB9APZ08|s-;R7~M{gR^NtF6btc>-|^Qm<<0hIJ|88X*)1hDgnzr z=beTC5Ns)nO;Qp2k*0W!jUWj1)Nx@((;=}fjC?;P^2z8ha}CYhlf$r;bFGmOw>Q=u z@>gIEnJAVDE==~dIC&-KGd&ouZ=nd-p>^5`dqN&C!#a}%c~>{=M88AvWAWbxbWJL&m5uZiOW4}8i;D$f_v$b^_% zqe)ZA{#IMPFZb0Cb$X<$B>enOJRYm1VKb__s3tVq+O_8^A{tZi6XO;~R4OASL8ZlAi|q4F6UPbzn@Q^*qRn@7Co zKuqOTxxsz{Rg{Sa#f;`mL)w&XtOx$@ty0LGRNPeC77J_KP^xH3G-W5ElG{N!&p-!Z z>N2ACIfXjmi;C;vEk;M<9zD_TV3{Ts`diB(x^3s9*mMg|)r{{w zckL#fh$7aw$T!cOz6wd$e)|w|^aCr2aD9A7265>=^H^p$o(C)OMP)iclQ1C`IPDZ2 z6YjPKbzVWz^b7J_t_2L@qJ4qgmW$i42a2pYrkBmGVWXLC=AGBQf+x5tE(stsnpEyXx@Us@=9*XwXe-*rWTfmv|%z zi^@RvfyZKJPZTWTr1bbl4!MXec`bF|FBA3Uc>nU7bi4tFM(Az0@R{n9rUmNCuEVTn zP#=1P&v&Zn;`udLI|f*7@{Dr+`OBQytw${#R^To^7d+FL3C!cnd6L=_H#&4`*nM`s z@V3mN|A!^h!+VmaO~a_BmDHmTu7$DK3gZY~srs|s-~_0YzEyWZ`l=5#yI%LGFZNVd zTypRcOka~Dk-^XI#0_wR98aT=9>877Ga!965!iK1xiv!UoOa};af302{q-FRU2zUo zl2R@`Fx|kM4~hqNS9iS03WoJnuRtL|7dGU~*yZG_I(>*xVvOITrlwyyS9!@;eG z=QWDEdrZ>vlXckTz+HLi_5hY0TaEIM(4!Sy>EYp>EL_(k#jD?&#REIWNAKU;O%7D= zGLV62$rx}$AM@B7cXbC&c`2Eo*G0-u#y4HAkL+n^H3Ds=lQ@=tEp>G&6wNL^T2C2z z?93?RKL2aYXX%#$?&T-+S(-QbYwWHEJ|{s^2{Bs>Kye!SkIQ}j#ivHEvmD+eQm6f2 z>@)r*PKo750u5vbP%Ppi`1*r=W#oN@L$c*@lwo2;PL&`2*VHf5SA+bl`*yXc+pw63 zHYb7if5lOmP}F7aX&t_9R3bDx0v=;We({m+3X+4kz+oo#rfe3yeu0!{H63JLaXEo3 zicLG&xa{GGf&-Z6oy`6`8<7=bfyheQ<7T(l1~8E~LWbwXyPh-1{9HkNDYeS}me<9C zYx`3)&40^DtJT;+d1a7g+bm8U2<#Q;s4MTz&q^vO%)0y>5Bxc?W9lttM9$7Q+1XP( z`3lAg7mg$KV0%9bQKV8MaIXZQn6w+Kw>| zd>hbQ|F8;0Z7FP_(Ag#$3eEp+JV&5*CZ8WX~6Tt%r^Zd`#EDGZKo? zE;p6>eu`Vp`Mj$4athaY>ABRU2uBlY%mP+8(iv|-T4XNiBYldTZvHTl~8}mI!0mpc1V__ z5k&HoT(XmHIoCbs9=ivwagP;K4q~=EI(Q&>dx+lVgDMr-)6Sa(#T>DV1{?Z}s=aos zs$mB3^P1@uJfb0)s5dWQx*!K0wvjFw$acH);)b^_?F{<*Gs&x%Uasfy6c-t@*SHXl zgYXFNbYH5lJo0zgABRfFn67Hp%TFr@tz*=##@ijNAGqyug;OPX7kSdwxs_`0OJ4pS zgo`z0l`LCycS+7y1E-`KqMjWoN9_9Jd?eA-Uo@x9vvMX}UpuW5T*7wPwr+y3M_(E2 zV+G!D&9MEo#SyLY^%3n0-H;bSaro*(FMbd2tjygWUp3~LA4a_mu(_x9Jobg&<0r<> zO`*7X6_u^L=g9)Cz2CNnj;5O>p@`A?A5h=vU{(k){bNGWHz5BsdiKjSVFDDX@4Gs@ zd8?Ym~mzfHY~H~Y?8 za^@Nw+Tt+;%cyp|nX7}#;cyn@qW3rVd0mx%dJHt|I4S;M*vaLI$Xb}-Pt+QvwKjq~E8&;1MA&sRk zzV`{R=QfXQUu-{?c3M>~{T?Hv(RSI354gXl}^4mWM9QHIl(NxcWfax zFf#whitqa5+AXKgxD&kY23%+291N7=>|;f3=P^N#@(k8R4AQ>sgsI&CKV3|wApQUT zkbjammClfw*y6)@FnUB;-gA%p3mGSQOP2cw-B0g&L0SqgMV^=iriosa9hFVb&$bAC zoG|+LYR=c6Ogq>RRf=sn`+86AL@1}&NJ1Vs&)2SmGg~j&V3)WBl!`b%{Q5!8+3|RX zSv)K%R$@ZxgY~9Dw2jKaG=xSz!}`ZUgb0t{pX&!sVa3+{GjDs@7e7S!YJtAbw}Rsx z%Oq^+|G0w?ec2qaCJKv8N#@m%16C!2R}6@`bVFOiE&qicAYyK@P?=Soq`e+~uqFb$ zIp6w4`rP}_=0+Y#mMSA=|A^`JEW_)SFkW*wlOBxXqsX#TcdOcf>Qfq@mHC>N>TFh6 zxiW|1R#}NqN`W81r;K+~osy~62Tk1^+)_@YJJ=S=X`^G;KBs}#4~{QBA#RJVE)ykW z9)8pXi^AF#0?>>U6J8lX;q-dRKGsJmM&UFpH8=7VZv!~YnQYX8EA3CqbqC#PQ zqT-L{XV)GH;okdcLk_ht@%)wl(y1qnW{9H^9Fj4xIT97}(G8X4)iubrMxZyVe%ls% z4a%yvC<7+LetPra)Fl}3OHVaQjbqLKghr6$OcX|+H5Fq@&n}F5eX&@89%|Cu3}DWJ zDmsI6v6+#NBs_{N)vw93n)PZ!^X@>TKvt_0oa zrg#5HSg%@SGGs<}G0F|-@ut)EPo_5Ciu)=(R_C<;PU^sA{x{%5hvKy^9qSx?PqMg-U32*L)0?_I0h=hJ=g=&)6)V#-x{EtsjoO)Kql{|Dx;Kb%WRb z@*r4K#A0{T+Xao!S6hJsY+`e?>U}S(3F1?hB;!jv;9I}OK}M{ zV`Qv;lk>2v?L7s@%AErJ+F`}Ftdz>gYBfds^3Xr1s+>T&9@4o_*1DsC$Y-prACRI0 z9xun|sW7b4LRu63EDID!7Bb_u?_(f=`77T!wvpESNLWLn2op4GnWT#|;d zC5XFLP}@+>&jl>BcvKSGqxA5opmdu9UT>DKdj7=xGeydN_;c|QTK=%2*W}lHl2I|AMFsYw%$V*I%DEGX zTvYi>ezyr#Q?U;GFpxyP`Q@<_}e)N17Nl1x8K4142xBuHGO1m6x|awqcI=^ zb=w_q)j}T3rU`lNfwr5CNxqTB=xgD{OU!i6NBC-+KKx$sLAsTgPBC~ScSAZtZtJqV z?fLX}f`;(v&hK(Cb|8gS#bzWdZ17pZC`z4K1HSO7v3_@yug~LS6PnixBlcjK<(79E zP8eEcv~PG+W$}|6D+dHWqoc7zPjyuD?Q7wB5zbCM$dZO~*N|49A5{;V0IF}B@8D78 z9CukQZVYypcmtS=7D0E*sx>4#M14L_zfbQWS-Vg?l|F6|O_w;!nQ#gGjl(1%74EQj@&6f#i96{xD0Hn8B07JY$UCMZrk2 zgLXz=8WC{>!2n;n(`zwYg8%cx-hnY&KK>F7QzfAsT1RIeRk!fq<0xd0V}myZO_Tea zu*N|vophp;MS=TAz*dCO* zoq&OhzVbAE55ZzCE|T@60K^l+!=(A9=!%sSYzuhQS9pc9?vMS~6a_$Rp1CbC#|Ifa zm}&)N>b`{5UL0NV-(vvDyaVBeaw3WI6qMfc>{<%tEDZmN=xn^4;7U>nBHE9Y_;=6| z4P8Wi$T1=-Hh+ybBF*(@_14hHa|sblu~#&l&@|@@d36q8pr}GwCIK1`Sm89-ybP`d zED_K9A^=@nE(QTu*uEv_RaJ+Ju+ZaclcU}&05xw0_gzrC%PGA^yyJWtO25A^*|s&c zOIb8X~3n4}T1M{@BZ9>3F-j16gdd1(#Z># zRv>Gh9PE}abw{@#eb{!8oXyy5>nqH*l3_UV0F?tgWPTt#f$nJ&U4yIun+HG|cJ;Co z1?1HhHPTN+UfHm=W0+n%kPnmBL$-6Yn}(B9ilE95JW<1G88Iy@NA`>DT*BVW1uF6ou#!oZFQW zf7Ffj*?xjO3*;6N3Rb~pNH#zUC-004hGP2^4C0hbAEdc^82zI#y*Ju;PXnKSODkhI z``xuL+5bSJKlp-qAHVfCj+dp=*hrrpxtbTXJbA&JsGgU<5=pmH7yC;dl<)nq4xOzI znla9`sF?g|>5zL#Cu>dMQdU@fD+`D2Sl?*;N<)bdQt%ajB#mTHm-vl)k8rKm-IX&@ zzWd9RC41*=5Vqhgb>lat(1&PHV{DFIf(D**V1fFp>&gx#(gm83af812Z<+P)JY+z5 z%j>4d1!9A%m{Dk0wp=W?LF&$(YJ+b2L>VLyWYbH(>k0HX8_rl;QJh~o=V72T<{olI z02G1Xz&I0q5pR0e@io6rf9r-WT0I~sE8xI&=qdCswc^=7k)cjpAD!%O)P29hBbIq} zI->arn%|;{XW&C(RR(IOr735+Tl$0B&U*hBA&So1Es2?;p$(aDgDN3CU;(X8tS!(0 zIBrB!IZ-bdHGazco_uX?2%85EkOw!YT#CY{v;}y#`>2(nDqHT)m_XmdihqG0`MhCg z0sqo=_*7o#a2ria_1RgxH{7v_65N&l1~PbAC2>=8G1)x++%yRAleu+@HeUdTveEA{<6=xEprC12FF?K}Vg zYLgKEcEBj=ZA$j=Hi5qZHoMnF-j()vgNvOA(lM^>Z;1PF{vl-5JEz|lEls^bRF1q_ z9mwXihY}RIUD<-7)d!B$0u#imD{Y<(Xo_BGX7XkeJ{ndh;9Q+$0pd(YY5q5yEyYW$sH~bA-bvY7;-%|4K{WPs%h+ z6j{R~Lx945-$LY22M2Dok^JcIDSG{KN8kYMdjbJU3BKIGV8HBqPZ(g%6Qgt%(3IG8 zpO{cN|9K3V2gUp~73ZAgvgqdN9kwvJol8-fE$6uMFeDr>)H!ScwlWmrMy{A-p#GsG ziSMWyYAyI?S{T4v2kH$b6D-SJNZ>XRUs={X(az2S?yYsJ{IkJ@PL`Gl^|u8AausuP z%rW2d0={+a9W{>{u}(nAIEomrzUn*TsNRePPh8Ervj?9VJ8wV8S*dLl9fkD zuVLry%jF;cX*KUTgayYoZ;41P4cKUQNd^zE82V|}f9$@$X_WG`8r~=hwJPw{!smynyWF&rB_2;}(!(O+xx%RFE~!7yb58gk(!i1(N?d=9;o5#hB0 z5AN8|O(sZ76>o(ChxLhhEk?Z4itJWT7OeKs;dZ#;u|y1 zy1q_CKDN;MZ+ffA3bcEFM`rQ5e-Znp`+E) z;sY>n8~Bz<%(W2~?0oQUjZA8Oxdnk#m_oR&xIPpH|4~tXPtzU*oj*E4L+@7DJ2T|YiR2OIqk;YHK9q_ zEe5=pF^A<{*d12$c+iwylS%SpT*Z8P81q@$9A7TyhWHFIT%w?7M)F`?|MPkgSr0)u zm^tX@!L{TD}zNE4F)4aCOB@$}?2b~wQhwao4v$h0zg>~b-Shh1y8T46~oM){= zdFX#?i-sZovNR)Yl#@*$FwAV_w7MJ(;2+3kYuGhv5(4uARbP3JLSZ)>4UDj0DT+N=?W zuj{J)W4BOdk}hP|5frv!a%q_sm*2C5ip0!%S;acB?!dJ^!!HUT4F_ZiZ5TKt^&HavNXqQ*|7FkfEO`<9(OdlcajuHVg=iN_!}KrTyh#KB zo)n#emJ!rC+(Y8=LA8GjL>ICrHGNy%0i;ZQC?xBRD`HYDku0FE+Se3$X;TlDx(ymf zJ(MMW6>t`6=!A2W?BjX(iwsLV8r*R?Y3%)e=m*6>YuumvgWYOHim}T_p+Q}^EWnpb zJPXS)DA;3iUmO!0xrJ<$EARTI4EUp9S$>uHZ{i9;n(}`eJU{KT zSUc9WJ#kWddF8o+;uiPZ*rA6U7SjM7iP5^Y%A2G+dujNK95)LbG1u?cQ`G3RE?&R4 z(KiKCXBvP&yF#JOSRuQ%YKZ+*{8=b=3O;zY2g_vlf5@yuT#NH$@F^iSRekA%z(MOP zz4WAtKP~i@j-dsQ&14*#?o~OE^>4(>zlp80-E!eeRipo&rwgztBgef4$su15Y-_vP z6W>vPbCXH2^%6SY)ij$8-9Hph<7$&eb>}I9ero-PK`|E`c9Q|q03*ge!i015jy$vxeMr`@k;x*ZyMwrYx$b^ z-29>=gn6{_%SXqeTDOcxnE4nm zYsXWv`A(AKX0HL)g=5c{JY5RwFtWgnjisE=6_++v&q2)e54&RGo_YouTT@slW3!1h zFwGU(^g@YidQ(v94?S>9YG|eKz6`kLves+8Y$o(>T%_-%E1`XXUnEPa`8uAUZuyv< z;?a0wlE8@dF7A9`KmhcbpqE9j9GHJ5sIe)X1t-aFH#E17y`U4bY}$lfbja_dp!+`* zT%op6=o3i8Uz}b0@D$fA*oS&Q>`hy99d~Tm1R~^N17zRgMV%PS(Mfm~lI6MFuyX_}bxMpe?pr$XB=WM53fS`mPH79aM?DmCj1ElOhL%?r&6 zDbt|rs@FeCm2p)mp-u)!w+w{A{t9gww6sf16|UErKQD`sWIO27!H%x!Hd zG_CbVaFB|+avcrX5+QU$DWJ@AsFfQwnv-2@Us;Pe6}b8r+7L?AsD9r_ztudRJJ zu~9=Rj_;W?smiN=)cV}N69YQ1GIc&4jczvLN%fNJoQ5J+NrK8V;n*%hTo55{6lje{ zVVxWGTQkykB_-kX)bw@RaHkJwJlQbaDzU52y1%Q;`NiRj`{8*j7s(p;P&l7+wqojU zDUfpOFVhf6UX^OQh%4C4TxjLla$GJUrg_)hcHgDgG zn7$3n3rw!o3m{g!@Qre3z(s{tA+T=lqWS!$-MNz~)GfwE^Rbe_qqES7aMv!tCd3=A z-9_s8{X9Q+3<2mhPh7mOXKq94u+D;c|#V~gW{NFA$8Qw|hMa?nR2|8;=E?V7)96p8a#aXJ#*`?@RcR*5BdU~0{ z9`@!FWd8Oh`XN>D09bjmeP_RNdnJX4qj-lEv7;7sqt8rC?Enf7jX z$F`MQop4RQ=(Ees80$x8-&Kn&3=*1?^sWM`6lXMz&OPRu%k|(C@q0V39T`s`*XNVQ z5+3_fct@VRvN6Rkwy6g1%M0-VwZpat!}&%j1*vjm&e58^dbLB z9U1uR@YOz6wH2)1kmno$;+`Skd#9~X>)WM6#LlC}nZI|J+JVH6T~mJ zsg(6 zCc&j|!_eFP`?$Om&Gv%`Dmk}?4vNlwpliqXTG_NRDsskH>IMeXu-fsBTCTXJW$%ua ziHLM80nLteMnM?Yf9)>r=T}Hxp{ro{%zwNgHQc)od{E*I@jVN+sp81tqxJ6=*07uJ zOk-z>sfp&VNjb&msO_$9(wOB+Q!c`(-kp>r?-EBMBsx2NBTy)yKJZFIJwIuF;5N>o zsGRM(D2z?l;IsWCZqbL&!(cDR8y)Nf>6k|?VjZ{zf7^`E#+Pi;lIV#H&?sIt&d3Vz z9)BzPw@j#haL1Apfw|B+mk54bBmZ{<-ds+i#Vb!=zjUj@=Jg%Gx@#vy3&Q}Uf#IkI z^Zv?<^i3U)63%G+RLWN!@Jf>TJ@cf@>jDwT4GMJkIXj)7ZoilO7G9iPoY!@em# zUZ!!T18dr|UTm&)v_rq>DETQ4R-*gqKoT4 zt4sb+uq`e}aU*EMdS?_`BmI?HN+^->B)mC{qP{&>$+gM~v7^v@qbjCk_X_Qj1ifA3 zHMhU`!H*@S8Fr}#(#0)4TENS56>TzV`gIji*#r@%I36X5B={`e zO!o0R_2;K&^`cV+FwUU&{SL+TPs$^j{o@IxY>_5-_mX#c84&ZGu={f~@P8VPht#3$ zi#=3Pk?6?c+NI5LW1?P%rYe5*C@M4|cPjeE?h4DJ5cW9TJdXO(t9??GSe7T42 zm`6pG%^)~CF>KXujW;&SVVGfqb=`vkR;5(GdiLN*p0_xe@YgyUk{q!o`e+4y9S)`` zA*3>8t7^&*i3H&i&|L88-c73;RG+R|x#?YW9oR3G2&0KxU_2uiv)N;A^l z2lvD!Y{)~OLiB3yM2RNhBWN4hk9+T+!(z!Vs}7(jx5T%$IC=U3^v=WRk4JCX&oJOA z`r+h7EHsZrI`oFWU8H;s>P7PK3sno5OtwD9zWAmsQe%haYloU#QSioW??Fb;ZFe*_ z>1uBuV^D^|pgMuRLk-b1lwSrlB>BFq1K)QsX1qPD zqww9ls!n^7dhf4#@4DI+w)Gh>^kgj6)tufHn5R1#?<0)5QhJ+cL__;XIt=@##6H9u z`T>wN&WY>FLG56I<6Zz4)APgFRDG6Ipf^bjL8n*Wegd&tB@@ z_-(1*&Rd+Rr;Y8r!g5U`>?T!N?m_1qY_CM-ax!c|ovw?0e7rhcVu!pdRBs@}EJwM- z*UAdhwSvSHQCeNR=`EOi*U2RGu6;$pb?gT&mt|Ej3QKGtq5j+$5{$BVyh80o7yYGw zZr;5Z@F?ME3GhyWTaO>P0$=5^cS&Q=r3t}{^&bH956`i1YncA2k6{{M9Y8bV=Og&v zc`y{@e1^VL`R-X0|0nkS|JQTS9lu&Ba}7NoZc_x3*cL&pG_T@|c|N}PvfelR!86;% zawP2DwGbctggu-s*t)U1kBm45#G((Hs|g=l2mKFA=N;5U_x^2q?@I4Y5Tr?0ARr(h zN|9azf>dcDMOx^nh=2$vNJpv=r1wyyiS!bw2_Vu5JwVF7dA`3le`GQ{JDWMl&hEYE z+}GzK)*Q@^D}3zlPe#lE{QuY5gK*p4QP=-D z(^uILPyk`clR{Ml-aWZ?a#AP~NhRW)acUcWcHKEYr5;W!rqcGWNiZ@jm}D zPR|9IcF7`;LMm*uekDvD#ol77q*bKj;7<)6&T?yRFz!54GuNtBMV$u+F z1h?3-4wiV+-KiFHJFB{?+`wOPH#3wYus14-rrQGF78vVVl9)sFguvW6vVfs@cr1T} z)X8_LUd>8*3xrG(EHSyNrziO*3GuE#Fv*!2CQ79%CPRf9%R3m6WCwkc@ z;x)(^2y)1ryN|j<%lTpxLh7(}nf)#c~Ey5eX4orW0Fh z#>tn)YhJr}*rc>jNGz|H>XkY?CXr4Nwo9;hMp5t_+2uL&|F=&GOf#MrT<8qxRrj_X zgd4`dFF6A1Tnfg2zzAQ`FShsZjP^}Ih0K44mQV#qX#_Lu>BHdv4HSq;lw8N(xb$1c zk&HGWlkMr$uhy_w1X~ng+lE?n`sw=7!;q+@?*tatvN`PmbF8`-KuX;KBPdZRRI#`S z_dYN^Ht6M@PNvwQs1;&Oa+*O!6IsG8gk{mARpE(&2+<=%pVj$+KoY?Yy2e;~$s2ZV zq7Wa3^n}HrKm`SBGKqM013>TZ*~k866}ljkStfxEuJ3CMJ^Dn)s0oI3Hskni?o0QW zIHdF5q88P|xQWub1Xer+t8a86NdmZzg-9*1*9h{T!jrJu@xByT_%Foq6>G<&xKWM3 zGjgIl){`)U_#%Fi#ctzn3Q(MAAkyWM^R(YTsUJedejpGpJU1Wr5Q77*;>bQUyAhdL*O*A{sX8`KUP#PNAhw=RM(L#G=hNX%1``^P(HHQ<_y^` z;XX6E>fch?)(? zg%O5kpTE@8`Hv=6Wtu4Mb&2<+`drA^M1tA8cF3XA^9sXDHYhA3{d;(Bjk^5JPk|0+ z<;set2s{~E{({dTY0ABY?GS*i4tUxAf;-+0#$b)C~x%%cr! z%z|3K^nn~tlNd^II~6#(0lq8Kg|4fWpw;LSufC*LA3bx9JGesi=}3BpmB?r}H%^uR z74*9#w?_T6v%6VsxsDign)PJ+qq1Y&uKd!j9;O}kP~=3fX`6xf#a7t&?Wy>NWs;VY zz|_+pwRw*>;35Wcd+xj!r+J(Go2O)_mx}KBvS%)Z&nLm`{ZF09>!rYdZ+y(vKFlph zE7DAZUlW7XU*ZL5B4hhXh_d?TcZ3d_FB`cHJ*c zyGGg2G7f9JZnC=iEg8}l&SgJ-uD?eo5Tv!^XE7`UQ${bnb0v~2vAS1)A$hmnXvhup zzE1x)eShhR7%CI-QmnL?s7E8|Q@?X}7ID}C)Mt@gOqy71(}>sIQQb?UzPai=19Sf> z4n>sCMIY3vDvk->uxqV#&&XeMV*$dfLqf}nJ z4Q-Uzs}WlfX%xl65Wq*V(hQR0P*L5Czcj|uOmW(39PTonkOc! z0e#?Ian>Reak{^d79I*|59)jLB_R?N|lYxGG9Br#<%dJn6H#JvC>fZ9(V#62Y#CwWG{s$;QbDyxMOz=Ov%K3{@RaBI`n|DBB67js9m5@Q;5u#YdV&@~b-+sAt$~1#tOE~#g@$lUw)Aj83 zi%_EQk+J1{DuF(Fg!>UW7Prvu`tm)l7V)n7}CJIN7T77Jlw=noRslH26ID9%gKW*Gt7i%VDtw z_{WhwUU2#i8u8eiVyZMy*e45LK~;o|o4fY|^{FP2CkfRhgIK&GyhWwZ2;F=8ycQkg02_rA2`@l`j7e0+xCX)eQLpZ7> zn_~)%l&5yVDfm2-M(#WO^?~TSG-~L@-Se3J=ij~wR=ACm$M8w|-bOECiMin=sLx}2 zd01p@+lkrpaJ(Vyc~*6d|J{&+en){D-j$ezILKD}lls98n7wY$wwn=*9xDwNu=Z4w z@K;8>YAE_}@$k(WnrJf)={^l?fe~Zf8Goz<1EqJJubsdW<|?bX!gLuCdvUjC1mI@9E(8sVd{SL+B|&o z^=z1i|J;fB_q0w(_PoVvxYS1CM~*u1leHeMao2XyGFi;&uUQ~OknKZ@c9n$`48Qdy z_zu%WiacxLE@>Sx{bw@9yDbYIw@zOkIje&Z#6*(N*oLR_2Ndt$e+cZ78>p{BKMnU% zrG2+%8WLc`WsJt&pj7xy;5}ziGF$ku@YBv>3c20nf@6YSVC54XB?rrGVI8hkIm=Yx zIHW-KEkFirus|GnVF^uRdFzl$* z#h4M>DcP7ThUy6#;mMaqlU$56pr0B3%?mj*K)b5I2)UW>8etejVmGD&(k5h%zfTecxtg5xR2e=DNv7I>?SJkS)K$%rry4jg=N}VEM(t=o4wjf*JQTC-r2@zw`jk8f|Gmoq669Ry_k@H0+6_*I_jqI z^`8zyUn1FR4g0U6NyZP{8f0)iq95nv>s5iCwb8%d;F!iwyFQsfUa-mmxp7xve*+xj z^#iRCu5A#&C&#p6Q1xUQXri_m{32cC`Rq&7XTE66lV0U$2xq@GRXs~MiSvL`o9-;& zoCFterj3(|-Ph}zm4Fp8MQd>(58os?v%36t!1H}B3_1h<<*bqkwQ&W z`d{$Sb9)-BlsFw!+z|AYCsU_C^S&Kkz}noQHQJY}AKCi4E;iwfOoVP|-GhD4k5+m6 zB=>bbb~%t&*)K-^rd-&wKfY)!VURH7!wj=dZZyKyKXHL8AQn}QCe#S0OuyM`gC-eDQMn|DzH}HbtzrZ{q8cE8bV4F0fK|#<2{n8?k@eh%wPGd0`_HrUyIG@3ZwK4{y|dxX0ColhrXaBbKRv zCB4gob9t(_kiEYFrFlD?gF9%8ywUy$)q9<#Z0lgiqWh5)H-qIEujLQ8^OUj2<3|+u zett!E=?EJAJg% z-;miRVA<%gj;q+&5yMFXvUJ8|zxwAKx;gza*AH@EKgp5Qp-z`<)2*IV0w@cJB$=hhR=KesYg`P+Q+4`Go>< zsv*_;{`1KZ-pt`!J!>*`?BZQI=Oue7YBbd+A1AlgxNn|Yk_}TLIoDamnb8*A`jYCu zXy(M;fl=K1IcJo&)S%M2QN>}-53}5tOe*_?%k%IwnJagq77ndiCXVi7w&TL!py`Y{ zvUIyo;kYA+Zza84-5)BWp=SY+$yHN>q|q>`4{xHOebwqxM0{-_`wF>Zi-lt zqp$KBy~%;K2BNu-k%L zWc+(E8=zP8VBEBMQ zUKfbzhL;u9r={KSNV)kX9`~DiT1xB2{qv+yj#w=Gn*i61kJJ%4z?FM@a+gTSGhgzZ zvIiN2ZYnauc@98`$GePfq0T;e-Pu)aW}>E8Gy8(GumRHt;y$FM)}}#VN7R1k#gmQ? z*?7x?auDvs7EAi~+}w&oDWF*7#ZSrKAKD^{5h_|?^LsDuAF^1ysF#)Rw*Pr(P(yr; z{#8eQzi8FU>&xlYke|&8mUv0$^)-OQWTVL=i`935CDd;C=wJq-7<-r5V_R`Ne6;{r zaF?akOcg!83Vg9-rLOEJvvjd`JLnihsJ~{Y1ingWu}EC0mji6zK*W(T_R8_!$b)u> zep&zwzbxBd{P6Q*FyD3jS?(dQ*K6l-$ct*XSjaItObaYu4M`na>q7ZeV!GmsUYA|~ zIAZ&?(2qJP)vl}GZU9<(ycs8{c;mD$entAIZDzLcGP^xqTAwpVUmQHaJxA#Z|88#} zBywGj3fE@lq~kaQLAE>G@5cR3r0MKq!pLtlZ*uh&+Y{d(J|!-m{pxMdg}d&>M0*>9 zd6Rd=%WxQcI|PaVX9JiQLxjpuJgwFL%0lQ@7jg2zKn7s-uJKc`PEDu)Q?Se{zoEZF z@ikiVrfSzGv(M~tx3a5p9-&z=OO2VD;F3d!=O_4AIn*C z<_`Q=;`wy8HZ z%m&tjEh3U}M>|Oi`p=Ovh?dUeNqRSDm>Sx1G(iU%-f3^whp~$Q-oDVn;~tI~6Vm>1-pQqdkO zyW;Yog<8EcrZ|zi;_=tXi;qKJS(k0#m~M@FL^#3zY%Yi)iX=$#c@`)6)oL1!F} z$1|CyKuyP?tr5v$i?H+bF9)g7PWseGr-Nf&JUjxPGf-ytjK1fLcZINBV-mlfTxAeV z*$eL$vW$Ov4M97JuJ5@IopRJk7 z=6qbj*JoI%Iug?Ss6)bMtu>H2 zYPiNQBNtYkGi&{MsPhrec|eabHy5KbyGWFZ>H5lP{}nOQr?1`(_06bbNO5p=qqYR^ z8AIut@L59=UIpVMGjPxT3-w#MvxQcQGf!Zctc^5+T5EQ*+tzQb=AT1N+tuR4MHcz| z6jmbJR-UiJC5Z3Wrd9ewX&>H=BUvGI&gaewI_CJWHkZcm2;g}3fLk!pKE`Oz^)6cVMX^N27n zjvQGl9=_z?hqVF!rj8C1m9Hmu1s4ZP`d)m!#JRJ*8ih9zNbD<(J(=-mr3Q|^Y7gf9 zX2nj=cF$D7_sQf4(`M3@#BY#KxoDz=RzHJJZR5g7%=zukjm>bP-hv~|(ad(=@2d;^ zX%8ANzik>%s7)#|Y?DA2zf4+II3hagB;5R&P1J&;vb2-iO)WmngME7=eC&<L;Z zqIp~5>t+5GYnmI%A4l2iO78CiJ z@$O^t6|lyPe{Y8T4wqT(Mbu27E9N})Vd^pl**?HY-O_&g%Quup@mz;RiX~N&d#J|b zxT6=M@K<4!2FU>}JOTvPg3a*Iy8v?$7Ou9`M!AzMDs$kM@-prw7We0cMG*8#U6d)H zm`dSC>1}bZQ!V46t+Jl|PfJ~N=MPqt!$W5C@fAW>)R0d$6WuzM8&nK}!VKoG-q>sf z*~In5s=vI1oyoJlJHfCH^s(g1eA zKmr1f)%uHhsMsW3{R`iw;;P)>j=JujW&8-Cw`m zh~kNq_<%xB8p%;?D`b6<4hSzyJ#Ro#GA!$rV2+~FQ0aL~bQ@dc0KQ^jQCS6cKt`av zX%?|*?aG9N?*n~Dt@5)++&9uz0YZ+GtI2C{;8Iuw|4iHG8Bwi#>X4piq6s!28R!cV zwwCpBfeC5SH!0zZ5iHwUXOvad$PVTte{Q4RarAuCnInz2-@^4 z=0Krm@ZUK;_5AdRuatac4jR>$e{I@ZOBZPzZgvqKpg6603H-}bIivd>+r2Ew0dCC7Bj;06rZ8$Gn2^^KK=rG;+J>sJ=Ey*Eig8eSnIWQoo(6D)lu_rE5-^ z(vN%cj$3d6_)#-MfYj5f@|)Sg>YX5#K{gW6SxB{I-E*IUtw_F+cZdYHB32W5; z#Wbui;#^`#E3A97aPQ*vS?ozvF3XcL%2wMO_1@1^m3l5cME`JowF%NjS|UQiMXtra zYcQ(y+?zDmAW?tVN^{rz41rQSXO+RfTBfK7zB8|5NUWgtz}w{;3Qw=Sbfq}JcfiR5 z0G;`VcR6o_nm@XEGI6!yx2HVD4Ya{bo1k^VCs+DU3tYe`&|kET58Y#rlwuQE_VCV0 z@Oyn$kFBh2)RP#hb32)fJn6y}tAT-!tBV!^auiZbek`NoJqpWLjmFqE;?dota$uh6 zn7lZFT|K$esQb!W<$T+H3&80Ad$*7X-6yuqxT+XD4$+tmu3nG2PfElg_n*x3|G>rn znLQm^$@JiTXex3Ya`)GZ`w6oJ^X}QB>~^kLyY2#t3Dsh^z63rd7`d*vR+;jx-uROe-TIE>Qi@fTfnkH zA{Jl>Mv^3JsvL3_OKg@T!T(xU^oFtW*mM?R)bCe$LezXW`gPu{8ZK)8Rr<)o3Gmv3 zXre+FLZ_)X3BQEin>0?M(aEg6H`Vz|%1bLVMBF=SZ(tkR+m=<9?#+V`9@`eejXr;^ zSaD>iKDFaCbAL^ZqT+KF#2KF2Cm%n^a{8PG6+3E}ClF6%rza25n z@b3+(LM_o>;Dhu(pB=!wqD05w8DWm7_K`_|tq3#Z11sz09`TLBJ3{P3uem=gX4eWv zmQmulDi^___ekG6-@05~wMxx@!7O5Ub`Uop@})+mP)c3KA#9naNtA5s)IyI}9gO?) zlrur6-%Eb$hNwm%A>u8llpm!XVt)o;uqY$~d|b3 z;>90PCG8`a+~>Uwkxi#Np4h@A2T3O=f9Y*&t~tY1ubNT~_UmpDGj@PZIT)e&+r;X- z?0V8t*8#=>hyYlGr$cD)k${G= zPv83vfgfd-d%>dbJY2{Amh|iut}r>9Rl^Uyzpsu&KIzQr04S7rZ?4YW26$HcAFAv|1EvK5e z0ycsouCA%EC|m&^f4yPnhauZnaldO&Tvy2#V`(>J(_})>!@_)t#huh$IPbA zX30Mg8w#J7?$%4!47{{@_Vt#i1|{D53Tn$r{aKzm`2IKr~LoRgGV0X7OlM(W`u zPQ@LtQ&6)C^G(isQ7)KoxG3|*U}p?48aaTEk^x1!vBTyM`A2pnVo%=1PDaltEsIR& zNSlPS+;oslL_-JGjp6U*{FmZf*kIRIPq##QPY=Mq`)n)B-^GkaFaLbY%_qWa?-~l( zr_^h`gW4G8sNd-Vj`S5v#V1C;$UFmYv4mA<5%aH^O@!`R#a$miJ3a)iBk>@_nH^+g zPpgbsv(Ao+UeJJ@SjOQ!U$7fpM>b`}&Ey{!OwzZe7p&c6KgxMNL5K%Rk_vM*RS$Y$ zU2I8vMAXM??K5}PU(89l>8lsp4@dNL=|(t0Vw=8y_88Ck@_e-WB+|;!U|V^AdrO_9 zY>b&F5fc+)d({CSz`WS9<6s7^ZJyI}*|&z2h6<-wHAFj?gQJ?P&>fbm-C`Tr$BO7P z3i=3XW8Y?zI-c0)T(Ts_we+2*VSza9!7$Mk;B%f30_2%XuRzM5eFBQ=(=!}{M>3BX z_l%8RVkNvxZ#~_@hW;#TQ@#=^U$_V03*BWz)~A;j3?1|E@x5ZxttkEEO8r_Vv;>_t zmMD1|0q)r0+}JvX9JUul~D7L^}eF=ugK;{Se!cApJwvqGh#FlZNZ6mXDTlk z$3G-qtT3LAPRLxp2xq#<<;-$wc1q#ecu~+rHgH$Wrm0DNA6tpqPPbN%rtUfjy1dzJ zT_5%7?MrhL{GDDDD7Sg$1aZLf{pPgJG3Z2eBzVHGrkiQUYu%U3Nox8g-s7QGR0i%5*XJ1*-`r-G0(Q^@F3a^v;4Lsvg+AAR;$ zkG_-{Opq~P-y!q~hNIAbH+G$m5%4WWHiH~qWBkWpty1R5+GIz^}T2+V6$gZ{}Mm@T=p9aah2Z{Q;@***k} zCjdS1XY(QO-!}3W*;nN>U#-7F<>p|@w0JzNr{Yt@^{ayz_iW6MaHZuh-M-6l_+JY1 zCw<_>XOt71AOom>Id-2m<~C-=JZX_Mo82b)5GoW`Y-I3_F`N5*JdOi-zaZd~Kvbu9 zYjT`@3{tyDr+TIgOQl8rqPlBsu=w<1)h0ci4?3_x$F5~6E?VffRB!>|g{(R!Va^M% z>r*W@@GWtW+-}*PAwSo`_dlS&E;=E^DXrEVEu>zL^pzT(kD2-jD07KoT# z>+H}gQgLWWyX=5+uQzGz!8J)PG0Q=HD}0k%x_~o(Iq=OOLk{d`Gef% z`F&vcyRvZv{l#jdY7nF^r5C1d+w9Y*;ET-;I0rXz zpTN6#0NhUAD|;1NPIM6k1gw-6hX}$6D$%@;VlSFLWgWloBe!A-VmQ~UKxR@nYP`3}Tgc`OC)$vH zdyE`CeFwv}>FP0;Aarv^+`+(vx#>GrmK^R<8BLp0Jp8cAz?bq0|JAqVED!1hcq=w; z9b7+>H1+8sunlAKT!KALN!?BkUv1g1m^0s%%|pF;QAI89lDR!yWG_%t;M4Xj$2&xi z?mCuG<_D-i8!Viab_b$XDQ}4U=t>5fIkM%_t;RFI1@1PrzUi)1RQX(#(C|RfNmynD zbg^+SAruhV0@uoz0pJ^ihKU1hDjL~}-k@3T62IVfMCDuab2~-oG2k0MT{6=QQXJk; z0(WaZ*|~T84{)yS-S-g)m%jV~a9?-5GlT_V$$YE)Q~ZNg-iYj|H*RF zgf$%*kpvSVDu;_0T*b_|fPAuNRzuDaZUX(@Wm^97&YQ@6Vn(>V_=r&Vds3ewErKwu zd$+?bl3e;@I<5|v0BHJ!{or(0HHW|plBBu+;TBo$TgHtwkCY3_<%WO75 zu#+cz=s-TklUILPT<&bn<`O+M^KryKgjT-ZE%3U&FZY_`oK%CQ{~_Ie(j92JMx`Oe z2Tk#FE9kbcu4l%l0hu<1MQoFve??P~o+{!^rMd6@A6@u5Mcz$bi0&iaS%^d36{y-n z>+_GYH4Vu>yrY|Hq+H!vcu9|)-)H;MyJ$&<+5v-Kc5c{UTusaj&7kTa3JP)Fl#@> zc4L%J(N`+a>~36;uldB@V~uUK_^+Nh`u;G+Z`<{2mT@G~NQHTkC_}q?@t&hR*eIar zeB9{hlEe5E4Um1iQNmtrm%J16lqy%N=5FEaD=xAXN4K7*a1hB>gK>`%4jg3bs??EQzijzUdo{`rPV6$ns& zvx5V!8!8v4<^i|?#befcGIH)|o)e7cW2|(>6J6IhN3Ux~M!~OC_wpwig9HGNk|(A8 z{~B)u>^HhyaQ`z=ysQ~e5&G$33;R=)PZIhCm8_u&5u* zoVU4c0Q?rtRuPO?1F%y$Z*SAmd}I({egZamg@@(s9Ff2zo2U?jineVtcp{DbT8N*v z=&6WC*aG8%h5EJcf;R+;Nz|&1VQelpTAqfWFKCTMPPc&McgbPgRe72)$%9TR7Hh2BIKQ(2oEBmwgUe%A#DZ7YbA zeLLeFil6su20#*3-;`z=(N{}pkWdM&KU19K(QMwR%JOAU8MQ_=)bp-@<@|Gx`nSw5 z@^#y|g9hkEv_RBMGKy^_@}Au6YjQaK6#X{CEVRr})%aMN@CGksHi?Uy*NHf1py#rX}KFCw84?n~9o zKwD3)yn~EafhMz}KLEZ1!u=aN0*_^fYZ7S;jrb%pOTyF z3+k^5ev(bFk%nsTME(jO=Yn$~u> z4ViAApRWDYdXFz8!Jh;g(I z>k0Z8={o<~5pf{l7M`YjamHa;JE;<}hE&pNDH>8vQo*Ih+cfd&;jbO3Hgg__trle_ zeVg}iEAs3?%*{3LI$B`gop{Ix6Tc@SAto@+lXv0sx3Y|>#r|X?IBwr~xV@q&I$Tnp z{7(IK-;a{1_Db8n;_U$&w(q_GfR2BGQpcm6ATdgAzcqMs58r$Wx~bTW*y{?KAP|c+ z&*hT2#QA+QkAPSoB_nX0_4mvLpuzmZ6ncI|e3<9HSihMhdu{Py#p+fBL%R{Ig+|pI+@qMYLOD-F4bx-L5?bKzSqBye(*&ha+dty9Dc7I2Z&v74E_ zmOAO7&dDb=WD={Q@=6-rBysQ%`xZK9WdJQlLfJ7?*txHBj=+(;9#K~P%J;J6--H%^ zybOUdY<}@h-45s!=;e_R{Tr=6(TWD5@=m9OA9Vv}txHMx0(LL4omX)Og8CAPpGCs+ zAj44P*mdyDBHyualGv{ETohNDKKu^!my{a@2|e^hc=%HG>bBhFeIK1Y+P>);KO7Fb z96a@1z+j2~eb0Lb|HJ-rB|(j3C?wTq=MWL6C6Q%FQ0UdI_mz+H{;KVySH(M_4?zka zAR&xz5c?r4pTe$SZY>KdN)}}f$v*%j-q`-t>9Xg0P)*Y(AShx}ZBEdsY&;{SJ@ulA z?m|Nzd_9fuUC(Cv^5yuX(IXg*>@xkN_1Z-=R)q;b524aKh;8N3ti$&RHP}I+sp~@4 z9=dz>%&~gK_QQio^T+s5asv%{1fAjy=GA{ zkD68?GtS$@w%Sbrt|M|Q34dkO1EYhLWhW>`P`ZwOIJm90X_-_xJp@XoxI=v>q)FtS zXd#LV3)S@2Z!IYUCWUo3k3$Zlx0pM!l2Immi`siAoNc-StHx2{3Q&%x{li7W!D+k) zAl!tq*UM^0G`9IyXwp;esx9{Y+Ke7O@CV+%(=VBAZA9Ihw3-^5S$#HU@4)I~eJ-|g zPwXi_+T;c-NWm|)3pM=5qpT}B$^W8fKExY+NIq^+ZiOayIcMt#A(?I&Q1#mrraoC& zMB_J)9tSA)6>DnMKstaNEU(k$9t5yj#lf&ZgP&9%C}lOhzi1Ay^Z zzL&U5j&O(?MJbh*(w<&?wcdGw)z`w4l9u_qq=Da0*a8dw{}SIci>WIYO)NL%F;Vuz zuBd3i2UlosFtDtj+Y%mrO}wK%V*MeKVe`Cvv+}Up0<`jDMZesm0SV(#TOejA`q zg_?XeJQdvs=S{PN4&bkDFNnGW;v=9y1Ki%D=PFy+a%90h{-w352Lr@|KW-55SLvMr zRkck^qc~a3SgkLIcNpf=o-#O_<$KzBvotPAAugkod_PcB?(IzP0@wlr&Qx)NmPGim zM9-vt6si3r<$+exNWJZB{Q&1#hn~^ zQr+)VGcnGP@4Wrw%K|AsoyF8dg0c4!;#)puROwvw5s3KGDS4ZFWcHxH9jLktF$F=D zS3oGT$HH6>6Z%6dSvzT3`h2$R{>p(mYbd9YIgUjU9YO7p$N8n8pzmu(jX@MdgG~I% zue6V?$zH9Sb=-2=c(`l`?DaG)bmv7->IK>A8Nh-7EBag5=HEPgvE^Oy)oo9KZuI%` z%)`nk0N~qNfpt96HMm&`qCAOgHKi7hDJ)ay45xA?FK)1p`YL*r*{7M$PZdZbOF zi*z5nj8C@akw8Z-m1>>gjBh&13lcdq_dGQSK_^RhZjkkrarwoNA7jX^H}!`} ztMkQ2f0`& z0PPQ(B{OS<6LX+fhJVXECGoQ_C1{eF49&(^?GUGRyZmp%0;=`+J4mM@uYr@RsR##% zXXfOTqoOR?0b;^SB_bX-fHUd=-u2#kh!jx7?}^_Qz0p6icFTswM*@FB=vg zTDXdJJ`epF&Dg%OncYq81kC?XCFdNlWa!Jte|nAaOwjrKwdn|?@db8)f)pSpsCegy zv69#{$kjh@siLFs-H+6MqjZ?&;VsktT7|eI zycF=?@Z8`=1B+|?Af_{0NMNZDqwJAt9(m^5sc87JrT^;JOVGdxWAUa>0u zO@bz$54tfjTLAFTJJ{g^1buA?!(;=c_%zYq3n_te1Uc;xTR2Gp%0K<}63xy)xjcBzdy`dq`=>>Y& zSx3T~5}Bp)w!W?WzNztT))QK`1toM-*K1XbA8=6UL)}-50Ch|%t z&cNve^I#0|Nx0F)4XSYk3o!=mctJ;y@xb4W{a?L@uk-u(ukkG;Q~d!hi+|>ORLlKa zf5mMHMnFw0JyP4pRQ5F)6XaMHFR$WsPuf<)g9EDa9kiZ^Mk)b7?Wiid6OsZPOe7>z z$?6liQ2;)g%%)ezD zq*LrOUC4$2gd4pr6{PvM%412DF)_R(FzR=2~xXxjv+(dN~XueS|peNkxRd&Tcv4&1jvm`(hhWAA{SD#r&|0 zZR9WX`hq`SrY{E%@t*b%-p1w^0G*!XX%af(zE+cyuvct?=+^=$qpl*AaeSd4#0S+u z_JBB?2L4+QV(5ngv)LfPXXDrs2ms1hD$2xvm+TT1TH-#WHpb^ge*jRx{0uO9WHi>` zr;KAq0EgGoc(FBDB3}H+*rIUG@W0+M<)hS}gf~PVL5_y#hfooSLp!x)eN&4h9&nGtbOy#qP0ulG@VvE^Y zTlG;35#^jGCKH0?D!rRFFpY1$#?t^$?|A1QvMK^Yg-K)_XT&7?z049wf9B*ir2kj% zJcwQ*mQIDGyDlk;Ik7*D;sfWKusH!uCJCbW;n;~Da@7o3gV3U0B$;R_NH9B`DaQ9m z{L&xGhJ#V{3$&9)38ghv|JA%lF27QiL3&wy-jQnt`3zb^_J-VCW1`lMCyBQw)^zy8 zu`$-f7AqfP@x5qLSp3MUlSWoxt>Fr5&q?7(MZ^$Y!b3$Qtd49zQJu#-Q@PqRXv)43 zNQ{;pBt}qC)BSu17fT{y>hp~{imw|gI0>YB$3wJH95|ykt*p=2!*tIGqQ;-=5I9q3 zhSb`-1ji&1T!jF%9?yK~%A_@ON%WwKIpt@fG5x40ge0J{02#;5pr}Bt?3#D>? z8HOCDy8IO@tDB_^j|Y)b8&v%V#Qn$ofycD?>IAeH9W_4yDtGKvB?|h+{P9`T-F2Ew zy}DMmmFC`Lf=mu%rM9)z7oDL3j+Y#|^IcZ|p>kO4%?O0nPbbWsNR6!-6y99G)NRJs zM`XMxx8YdY;PZgj&Da%+1$|lmT^0ARnos zM7pU2S=N|i{5AQ2%p6+R2tj+M3_Xfy8Ur2f#kM1+%f(R`x&;N%Buf{M9&ZZydjqwK zXgG+Wq`_K*YzqLfrqURg9z=%#8CF`wB*D}`1xZo2mo#J9Xels94|Noc%Gk)d?svup z=UmzdM5o0who?sniDp=54vu{6KYSM`#U=PYE~Kx3>Qed^ssmvW*CQS)xiV9&PepX+ zfm;4YN2Y5-s=!`p2y;J(5i1Odn-upUuyy|-doiM(B1}}K;bNtqgjlcJrm^&GYhO{b zgS_|lO<*TIo&b_j%&F@O`9#5_#MyS!D~ zsVNk)g0hwXVxJuzV!V43tl@R=o<3V6fRrJa=Je_kf~Qz;umlq-&!wC)UQLft1Ko~(o8FVe^hJ7|Y$Ujq;{k#$sHdO#KRA^jb4=8r^Il~i zaK#v6qhp(3b4qpQOubt5&upP%PTz&tiuh{(?{KV_3iqV0$(|VYzGBL>`u!B?i6=0i zh6NLSAWBw5;bTN`b0?D)yM=!B8Z-t{K^L(tpSRJ?D0xHxqKI8NBJ96C6P4Tlh79BS z(&Ub%1CcWNQ!|9>fmn|OtgfkeB`39Ej<{$4BTM%=-M228vEuY0AuFqgD2}pDDkklf zlmEjT^V^;Dx>}+LBNT?kaCIE%@Pvfii{k4-hlIeQLK8O+fZ~u{^GgoHm5cvSRwV55 z?vqb+`W}<6L~{?N#tuvDZT2Kddsl+<2w-wd;_|6~c3a46JVUs-Q2}o5&*R*vRs1zc zv7-L1yW|G~C~DMJuylio`S>1{v3fz74Tk&W?$4#yfrs6hdlkn+#k`mib%XL|saBpw zMh~KkomzvP7P?myf(0dZxa-%@+9MEkpqp67?s(8)8hb*F$qQnyj%AW9qmc349!oyG zId=RvDaEM!_}DIE{Y@$&W&ZguLY_R`{OkzH1k0iRjjouk%;M*;U>zCJC*hFhL8N8r_4S^sa}e2iGf-V8h>=ce`R>-9r>AJ-HPwyB`A^`1IRM6~-)i zJQ?j9jK|mi^(}ZhW5=Eo zXFn^PZ;YHIzhmQm|5vP=2-p{~gS-|9Uruq%$S|8h zotu#$7oI7J^mv;vg^t0OCa6E|71w!)*pBQ9a7Tw!b+%G#6!lUIKV)F1g32?}_tG?X z)Qm2<6x^Y4ouwKIMCSi%fgbh{g+G@tJG!;w{=81WZLEFe;>)|yynEvlS4zLhd%wT< zK?db&WZE6iIK8__PvslFiRt=rLpiA^1^Q#Clzo}fp~}YYygX(T`;S@yd-Ex@ck+dQ zhL*B^e5KfQ#1YB&Fbgp5S@H*qptXM!)|FT%xpZ{omQ#7zIQcIw>e%KunLXmADo;V;_hC{X!%bcPP0tp2mO|LKa?44-Xb#j@tKx? z><0-1y3-M=4I^T^j>mhati!E{k;=9tocM13P2eT!}1@@R?!ci*7?EBp!1Y)!_E+W^oq*$JaSJ4lL9e` z!u#t-mr#@@H^F)|GFNU);1;84N7nkg{@yE{R)7F8eWkY|7bDI&7tY{+@@46U>9k#p ze}06~;jWtCU!3`3!RtpK&KN~|5|f+X6zP~I?3k-H)3dD81-N`u>t8WfmK1( zW$+_Z)dO2niq>ahPWsRW$fTPJu<+Ga1Y%3XYc?R{iR!seP)^c6vOmY<3YJR;d@}hUTkh+5R)kYNM8fkT=XVQ?Cn@>BTfP7<^YT^#5o=zfF z$|*vfya{R`drIVE*Hm&xh4@)TOXN%gEtQTT!40_pGFHutb(p11ASLc!ct{?a(7pU1 z^q|!*_!SuS_DKH`97gUoXJJ`Y4IqCWnYFWsBb{hs{~uFl8P?Pv_kTL18$l#Qx>Q0z zx9F2f<=Ri6~jlt^t`Q88f!M!KDuAMzNJJ+tW&%WRH z>;2Mr`iF;%8x4B_#MiL-{lE*9n_=Ihi|W0EUMDk@VQ0Y0rmlCZSs^|?=@*1VNhx6# z!NSIYA8Bza=~16FFh+EHayFn1zS_X&Hb@X%BY%TwpA+gdVT_~9`6~8E++oP*I|{&94@os+)**DP>Id8E?sK=s^Rq30 zJ&F7}%E`;d;{!ue!B@Peget3p&$gSaV1d9KJP#Z536Tr(OL_cDsv1Co=Yp`1yPedC zD7qIu$N;?6MljV)Fj4xfwUMh};bqnbvqJD~=rQtEYW0P6)@3A4gEwNC9rHrI-P_eM zmGT#az}Cm&HvjOsUF@P>H|wrRBZoan^!39Vq;02;3CBo^6-i>&PDH-?_;E-~YtEVvA2b9ruJpc{+%; z1%`Ut$x_H~y#3?I&vc`-!?)xbLQb$Q_Ru7mGJj_KfY$Z2qV}ab`*QkMq5aS%NWZjd z{{+9z5T8^(#1;C7PG#4#56dCYCyH!{q~5+4O=5N#ccrk4l34*}`#Y$x1175;@e4p` z`7J|R;8bu2gR)8hjj`z6N2{GCXCRS7mPhw;`nq({CusR2 zaO2u0sE{2vLm`H+Q3l1@bEf7r;$Q)}8K1IWe^~iH=c?eg5F0+-1r6dv2pGE}(*ee> zmT-qXR--^~P?%ZM>Js9unB{x&XVm*9zjhN-H^i+WdqIZ#BeJzZMpQiGo!?)}e(l6+ zKD0jU8tf1cc`4xExdG)HgRaE8zvWl}C;9>-OW<31{jfp8C&k*b8mP(YVmc|Uzvg#V z6!#AD2OpO1Se#Na`BCf$#c=)H?Qj6bpJVB;fsd?KD@@2E=4VfF9ErJXSN)B->b{k2NM{PO zr@QxjDxWe=!RXjRuS}Uwe)*l4M5TU(s^paCh5A_9q=!Z5QXp+a<>Z}$d+mjJtT3dN zeK#2`6M$pIdXMpRnaa6h@+B4Bi)V@QodusK42Do5-}=r5JHqg!+n`v(<^xcKFL zjz zw7f){lCA!CmY|#$^lnoM%mS|{4ug5xf{vVVCU|Bd(cb&n%k*V-h>all=e~c{WZ!lF zmIRr!eO6t9&W+u!TO)AmM}{X}w5J8KyWkHs&siD@_wg>?%xG8Tx2%|; zYI=)Y%&-CYx^3`eKm%oCJ#-HF5iPY9^|&pX5f1zuEpb0nZey$mCll$|19OS-{_gB( zm=4SL$BH1JE$)qh2NvVwFX;}R$94HkGbhP+?jwkq61VqwDJ6<)@6OrwUciTGTy<_6 ze-}EBrW*6k&0qa0b87|-Z2y2ou9Ab=?RgXn^@Kg@K4xjtJY;s4C2{*pxFD>N=isJi zPgSvri9wQ`rWHswD8KrcS@H4f1HX3l(>Bqs?`-WQv{D(rf(6WWd~VP<=Ji^o{>}wNV9v?TEg z^G>-7AaS%JnusQlKbg=AR=<_o=Liwx3;R|4UCxN>C}LDD_V2}YcQ%l=DtTr9{crBu zKzencSNTgxJ1<@DE9vw|-v)eQI4D3?sUHW&OT2_crJr27^-e(;U2p|cTH+Z>DsRf2IC}O9}F?f|H#Vw zr`LD|eCFjp6=59N>o(_)i~OU_ z8DW+Z(_YO^2;8oc-AyR)$3Y2VH?-zJ;vB9=0G6V1Qv-01$&vW6TxtHP%8!CPO=4GB0?zcyR-@A0DkUo@tRcy=I_x9Ao5Pypy zUz7F7hBVp0^N%3)4@g3pioXtyB=A;XuGvo(aE4!010B@4dt1{99Oy{1}BpdkPo`c$|~q%3jsw;am;H9>Ube*Y^XT;!|wPk!G#~0 z?lzjjPF<>zz`Xp6_g%oQn#bO$Oj^!A>Tw{vIhS|YYTwo~Nd6Sf@PP&@O8j=~P4+P< zOW7;@T~*sdB+i1Kp@}OhB(OPy=ke87oWtm1*_-6$_f>ACS)>xHEZ*ob7YxBNrDaMi{`|U}5S9ss1^?BRk&8Y)H?m5%WidV`^Lk(NnpyGa+ zgkS0}8N{CN)`x!frsGqDPS0*^!QTVsoV*uDv5yYYJBhoyH9xU#d!v_cnN}Cn6y_tF z4qx}#`VRdeSKDe@)AtK-eZ4od0GPN zNr;PddC3)oarFWToN%gA))nIA&-uflLoPh7n4dux&$eyVtosmWETiKm!EoUWap^q% zKM(c~*Pk7AXMo2rMmsM*KLaZ8na@Afz($s6&8@-k)W^@bKQlUFB%4rHsBD`Q{ue{Q z99O}~Gj^#_50W400D|-F1V!8=RnSp=7ip_=uhiflZ`iGif`V+g%A6*bJ#0cut0bgr zk!sPx_DrfN#Iua!_!6A$)>rPL`E->k{C0Zp!r!_7s|Is}xS{zoFx6|tD^(sl=N>Zj zp{FEBQO&-e4rM;*dA^B%w!1$zKo|B0R`_PA{5oH4JQW{BEyATfgAJ+a5*=mB?o!a( zKgPLO1odDOK~f$m*dG063>k~x&;UaA=zxiH)h83UpDO5 z;9^)kuOEVZ(Gfp;@)PaX!!GdkowMYBJm_ZyvgpuGLZF+eo@cLpn&%6sN=mVVDxyJ) z#0p0rgzx$NJ7}AtyAy=m#jOm?j))m#bwI(b(?KRKn7o>Hzj^>j*;ISWAno6Fzr=8e z2>^GQ7Hu{8!)nHP+2Hw5P|YHJ&HObiUB0O&gdm#5%0jZgIfwjof+o`6aU;a9d2opE zBl_7@QnbN4K^{r+UlbT@tJPJ-(ecll3(NDv`QSGgsEpI`fZwZMx?=Zhjr~0k8?Dn{ zfMa6mdgNNt-2%LkbJNSa|3Ok@s_vKXUlA%5f`u z^5eJCFPi+so`iu_?B6U0S6MuQd5Rn-UDefyp!qEdp5@Uz7%be&W+{q0n9p#TAom|T z%++X!2hTkTgTO5QVQ&eS%BjI#Jp>lJ#!|Pw2?;UKW)t5IG+%bgy1tILOGUte@$p6k zEIh=6owtg_fc`bbQ;_e-iTnm zo2B@U0yV(IQyBBRmpRjw+X?Tcm^dcQEdQ1L7OPz$(2VpD{N{PgJG8s<>)YG=Ytdc; z$E97ZX)q@OmnR&`nBi_zXDz-q=LKlL=qn}Wb!}3-bfg~q=J#sa}B0*?}*svRiqRN zK`Gg)>F1c>&z{b+RR6Xbdg$7>|5JPxxSi1YvNyus0FCm)fN|VuKXLYMFRveSAzYQd zb-@z&ldF${XW(-bm^05$(S^gtg<__^*MN+5MQd(ny#rLmc%Uv9YYdhpyZ`WSKW5Uh z7i66ypt1OcE7NSE3k~W?l$%W+Bv6lL;lSxQT95*?Br{WkD-RLRx+aB7EBg>W_?+=u zqyCW<4V8s1e_O`DpTg~^AC{P3;Q5*g#8zoS+$7|kj)gamq73wa=2>rR-QS->((F&l zbzOaB0V%=zQ>a?zB8pPnKFrM4X{J!Q6>}jJHgK_czi&jiln021#3~xd6pKrrJHs%X zjP0p8@i$#mi@#Lv?4n5m5rwI_01`_HoTxq#vJdu70rcAH2>$i|T}&qQ4x4%8EcgIG zMl-#_)p=vl<%G5}8Gh8d|Ig^q6^m8)*EMRFOeV;HFn547Rs91Vv*k3mc2FFVvs~Bg z^Nn|;jpy6eyeD-8w+EIkpSk&dzd3&Md->V#>s?)8GBA?F`Bsj+Zb1IEtV%P^Vx44t}1&Q*0Q21ctuxW2TwpHBc3K;lCI` z+8?(4Zss{Fo@wFB;oJ~|Eu>cwIf4HHFgGadh?vo)YIELGpTkbC(!A>{u&wUc3;;FW zo1?xb3zd&;0jK&adR`jG0^W}7FMP20cZE5lU73jK=%b^~hMYqcQm1g0XpDUN*6wKe zLopN>YDIIh52Rr>E7oHA+aLxnJ&%{39#q4i#PAQymX=fZb>0~zNr`>gck~L@1kYJi zyqgU*ChAXt2XXNif-m>a|MB`@Z{*cmo{AQKM>3kN%ZU+5YjVyew)ad18W<9OtD!GF zz^7Onh_~Fdm}|L4lwsd3!F7`u8Dw3~x1M9bAOLmd;5!MCT`lXm>cC)6CO%D7UB@;q z--sQF8#)AIkuws7jPc$Al_2V422Iy3+k;o?RgbT57EX`GK0iLlnEe6CmK;t&yJ9JX z8fs@}tZQ&GK9DQ+Jc}EWvX<@TH;MO@Nli(PRxg2^-G|L3OsjCdR{hrAEAJ~b-8`R2 zau@HH3M2v`;PTjU@zB|l`#GSyyg_Bx=5HyN(mgM`@KVRRb(T|^yV*gfa3Sj` z;yGjh?Hb;D2MmkisRSpO;!+=w2nk|ShJoKFk&c_C|k&1Oi676(cUbQI4^%`GXngy|54IcxsN&=5Ypk(fT_LFrm z7~AB3^75yRs+0>DDZh~T&>5k3Z zs7Jxmof2&i=0Z(R@G0GDrPsfF08M$|<-qHaMLw;XXV~rW+he)nnB|RK`NFNrd~*)e zu;c|;8j}3J)Yy6K?RSa&tW;LJt}UQCOKyVLy<|^ z8tsI`d8JwYE)8RufULA_3RNyHw%-88y8u|RK5fSw4e$tXf}Nq)VvhhXjmtP{qKY@k z#QvQAc9VwkOX*1{HH%>MZFTT{epDh3KX(>X!`y@FgzQ}TFr;~7M>zhhx~4v8Jz%h> z*R5SL&0T5zXWbxgwp8nwj{><6@=Xv$xx^rLc1i>b?Ki!=aH>sIb}OwVgZ7bilORW$3o}+F35pI|pK}5rCiEW@WDANAbd}^r#Wt`NtkC5p!=Yjwk zZO~jkgss$<;rpD$18RcB$F$1GCGTHtV8Jg)a@So^V#K z8E6<*80FSrw|D>h@ILPnZ_F#Nz(6CLI_vwfrX^%8A^9x)A6lG(tM~V71@{pge-}(! za2{4Ie>Vg6nO^=dtAd1{(hk?3ihWV5W^pxsJ88I!+c^%-W9q_QpST6du8Au$D4*3M zSKg~#p{%W(7>KV4h?i}@E9)f46K1gT^=)*9=sP5+)=G^isCdP;ALw$vOw9_RTED`l z2wjY~)P_4_?wjFc$#u_0vCJCP(lsL`Pq+dV$VB!Ayf!Hx>M%-nBc3dw8=zoCb8dEl z`NnX)&xA%U%z~(q1f#lZ6c)dwZNUNg+UxG-rj5PC>UqTW3pq4$v8S~u`3Ze%+KK2T zW=#!`oh8L~q7M@sy$;U&f3hwYfFxW*EmGs&rM)Ke2t+0Bgg=#LZwYOW`JD*HE*9&y zz<{3VDFHKENUdS{6ZZ2aaRFX3B_vUMZ^mIA8!(LCh72zxbOx-S9C>1{)MA|yz`RYf z1$)#S3R4dlWuhd#>wkzmC#oBH3H)tlFx^OWp;7SC5ZXPg$5~3*vYGP6mDBRoO%%o_ zoTBS%@~A6XDczI~I~Sgq@L_5zJ?@;HPS%6%;_vU=iBnHU_@twM2t0n-^p6QPmb)d= z>>TQ{a1BPgUZBZe=3EirMZ6FGt1)#iht|H&{BM6SF$9JMBVxH(XAPcYZ1+?2Dt0`K z_)*G+QZ(5`C}SAietQrRmAdO!T7L^P8)isRpyW67(3neX`NTh=1P6gHAE>OiV=dN; z@BOgMxhR~JZ@*#ccE!qqKsUWZa0oEsK0+$@%qZj;6Ge4OL1*r4&zoVVQDLiU_u8Hc zm^^mP^8Jh29ifmV)}WshNoHJ-0MOu#6H&5a?;Az_Qi9H8x&mb94MEk%*KF~Jh#laj z81hXWI2#TLwU=YyX5hw^-TwM#y$0OPt4B-y*VP!Z>35xO60fX=^UA*sOj~r`)^q9J zvOy4bRTr>10fXz%ySRQY7iSX}mxt%@2I> zjc<=1i6&_LY%$d7w=vc2IJGeYTB&vbm898@>ET(G6w6||=hKj+%iRH8*>w4G?$%?7c`WNr? z9Mti5@JYz^@1>W@m@}c@E&=cu9(DSGiG|fy7jil~rzRCO*ge*sCD)r^I~*3+f<;*B zk@isD$7PQ-mE|0ITKDT-dC$p2k(U!P9tz6@`PM=|fph@4kO82sXIQAz%2nQ_fhB36u2szPMDx{doAgwjH{VvQKOEMLD za0%7E!V792l7H{CSkZLec=xZgq#-Oamc=>$wX$X}V>+glVnSd0$aA=!@Y4gwTcxKP zJhFjH4U&lhtqAo9({pj$tB zgjQYZWYvTZi6J%2Y*?)AKMVYP;zlacZc%~P+AWl*!6Ql7OwuC6;@!^QALH%-4HKXef5(ZR9rB*D`2 znpnM?b97?hjx5Yy-YX2l5ch-uu0Nl^QYB$qfl^8_l|1`Ra<Rvs^kEo3oekP%|K~4ia-Y8hrA!h&45sww z12-Gn5per+GgDvLRH7(Q81p5%U_LLuheEK^G8e&q2JO4L^< zBiE%CSUF!SYhIAHRnYuC&M(oNA91LAtLq$;37t7T;8Z=>9a_Jao&_R-4Zdfq*LDS! zH(Urxx!&^YVMsbq55N{*?a0P&C}h9m!BZ}8ZuZq^8tg9CP63>1ir!Ur7Ulfeep-dn zHj8x(5`b&OaMU}V_fdXprP$PlVgl3cj3h3ASh+4IFyP3MPY_2>vCuC*3BHRuvbfq9d<5mS4e+shS7vRAh$ITQ3@h+rLT^@S|ep#`L zy*--nz4$z9>T-J*L0t(=zNZ=YB(8%$KhzdxqZlIHV@KNJmA!)ZNLFTDouvi3SDE*y zu@>5-P~NfmOv}Tb>JQrmQO|3kzhW6v{vMQmQ~dJXelg9Jpy1B^XuS--HK{C=gX|TD z8AERoN8qMbwfV#`pTnVlCkJq*J4gW5w0S5Vi@%kxws;cZnf4u45iU1_KL!`Gr;&VCph`E&n5AaP3C~@71QI{joP1k8aL!$sm2L( zjMa?HTbimi3qY7){<|7JQLj+G80}{4hrIiT`D2(zY2gc{f}h-i7lX@If7v%Cp-(E{ z7|5m5hc)mvWDWM_)`aWQ)vXMF8iri>7+A(?#`y&LwCkRqdokuGJ&jxni^3 z;^V^PV~9xqbDh1`$Ukej_4^0qI9d)O>n7)>Fs2ZP{(4EQ#Aa!lj42Ged{3>nBxY_R zyTv=^d>@RP6+fd)3^bu;EOSQGhmBvKAHTC9cod}V$@$0|$ypo~>OJQdtawI25# z+UQXDozFX{LVx;Rg20y))EC~mN2uS8r(xAUK0@uM_1b9j69P+pw_ZszW(7~xhw5n2 z5L=Gp;vH3>Y$Vm~;I|1_yB|;??6u;bUC(3uax(S_A1LcD`hIao+?MVu?mu#JRc4kJ z)we^(IAO_pE-Tnm`n`iPWPpS$%O{Lw^y3mU(Venye;I_YnAz56P_f>3E{C2S^8g}t z)vQ|O$>CqBggXlDoa1{ajn@=ip&A$TGSA9{nWP?D+>yyno8L^-VsO|Bo~I1}I!;|g z#)Ff_DqT+Ie}`jLu>GoKwj_44w+cI8TX*A2$h|*hnY_*Cg;_Ca4f9&^jZCmJ0a5+L zLcKW!W67gOu;pIEJ)Q+DHpU35=6aaj*%Y1)Y>w5~5*%-&b$d;I58+jXFXEX(T!X{M zeg1eo%Fhv3xj`5GFq02gfp`3THdFX7{&=vf=Iq0|+|5sRX##o;XCGHHa5*Mmm-20H zb}Qt#XSDRT-1v1}62nbMs~!<_M%Zwlt?aB526jz1^sZ_~1Y>otq&}!0-qk7XH)=aG zcu$)-{ha{;-LLZGw;XNWILIIQa1Pzx6n&T0{tFcU>DMI1DJGtfL*ENy-WH*L4 znzmNRg(_hDxH61?=AkSH0L=64Wm%p=8bYHUu_(|{wzLB7P>Fwz_kEY1ucn9j9Q@O{ zI9zOvc}`qn?LT@yHHDHs4YtI-I5VikPfs13?idfRRh*8xkiSV_D!%yko@i^f>7Y30 zwkh9!x2MoQNJ9|z#-YV<`z;GMe32y{e8!+)bGT`Skv{tgSpXNKE|JYa-b~RqypR+<|nb(-$_rCU^x~7=L!nx z=AN;TrsS^E>X__;#IFTcXsKEy2XY!h zsG>|InVz7W`-8Nii0Fg;@D14T+6xLpe(VL+epltWM+}|j1!srO^hUd`$cK7xkH=RB zzOCW>oRGWln)HK#gnC!41TLfsaIy17`+vu{YEP8h79M&45X3kHihnCx{Kq}@y?wCG zW!K+g@}a5;kp%j@w~ey z%G-tmX0$7Qz}jG{+BQ_ISBh-lW8pvu8CT?*2fwKV9Iq%0f~fAT3Um;OhboO+f9g9Rv5?nL zm5e9BRV>VlHN)V|sM^tOJ*QA-qLW^LOp^N3zJ5MXNJ1|3%YerTl53v=@`~IdPEkGq zL0k_O_7F!q+9M(uEas5vNR`_@@m zJHK0mdjlap^RklZNev4OeOo6;W?1^zWH!cbgo0aTwEZw;5i|{&JCHqbi+m2x3>*Sx zez*R${-xE}X&;bl5o=-msIe)fg~x&K}w zc?+@z7f5afaJWNL@=<5(Cw!Z@eP7_&-aRR7*uizS<||XDlE3u=su7>mYss^VJRl zo&bN>SK2MbU39(jEkD(no@B4!Jg#_SsPsFiH0uDUzZoF>Wl9#R?3;F19{61<^VVyF z@%7eTo`sgFC$V^zlG55}Y`tODI*y+T9 z&wu5$A4z_NQ^c@5B6Dd>1pOVO#?mH5B|hi28p@{7_F6rPJE$BJRex7I&6t}-n(!e5 z&&=3c*KS~f1RY3}TY$p3SK{&TQ!_LKocy+_1Wlb)XCk_e%A<2B%pHu?SFkL<7fn#! zlOl2rh6D8T_{jdMT-+i6bR+Ct2j*R|3_cgqNQ$FQWuYM(+`*Hd%hC&q2M#kp@3t!n zU53j%OV-^qBxlgX>mMvam3!Nh=vX3ner@^?l7>Kjb|(O8dxR_$G%h!G$HdfySa0lC zD>A)#e+=>TDXUY7hJ5^J0T^cF!XU~3>xAUAbB<}HKugq%I=W!HHVqiS47?i@g3fg& zVtq~THtk#acIuoL9GWYqA#Rx~ZPR}B1OiSi zQOu=C5;+MRQi~s}(#e7~XU2_&(B|&&L+fosUSQ0M5#@>8& zi}+Sl;I(%j_@86x^zROf)`9si>U&phQvGvz=aE(?lYre$$_~ zQ9UrSy#ckyVJbEsNAarG{gjm#5U`^Mf$WkG@VqQl95XZB8Y`%_eZ?&{XjnFyUV$(z z56KGawP910i(aRGVEJ<^l*JR`cZRz%cBrr%q6}bzJs5Z-16p9=!S&qRZlBFB^uDA@ zUe;F_vU%x{K@n_zfJ}p>iW4fnIK_W+kV63f#(vinaFE0LyOaF6ymhm(yM=RG^KdM_ zfPhlv6@ghCDkS#!1Y{L#|DG*P>iN@b$|K12hU)=m#VE|Lr7X+arKu3l>|QL+?i$Zm zkm)EqLoV!Gy|X^(D2kd*ZQNiqzms=Rlq21?%y5Mzz+V~r)IpvDtahc1h?F? z1R4xGZo$o3nq^eTjp1dc^B>i#DV>_B`5p{NYONiqRc!DVSS2byHvwby-uB?p_9|zx ze4x3q!%`odC)uHIm>M#pRbVwt7IV$1P}fw(pnVn0FL|SJ?*kFXRx2MuUzje7 z&h3o!i}ioQMHv)(`+>fuq~m)pp6?Reyee@~VrPKa_}((;wQ0&OWL^UO=FjGBdB5KN zZ8a@EP|tOpyn0vRLU4FkU-1rpok7t5fPt>pjN<}2@%A;D{NC19JuwnDIS=anXBrDO zJ_h~pMng^n@6{E5zKDp3=O)%B)7_YOyk9-pd1vKL8XH~ULD=HdW4Rn%6=;O<(c)|5u4VOQz~e+c-SR7E{y zhgv9Ujd-2{*L}$SbB1qnAOhM!MRJFPdaE!M$??D)0=H=j+Uf0;&#Pr@tS-38TCL6z zyn!tdiRJ2@lSuwWimT^x*j|eQF`jO}&$|iDwaoIpBcIAP=5@fVq^0XlggtGCOeeiZw0MD1`7pW4s)d1C%U zO=9bBy&ugmj=j>A9g6QwuWMQNg9yQd98p)Q(uA|70@EIrNLzKO%zp{gVk5+%IN#yA za$W#F8qBAQfWEV#MRi|pHY9WDi-r{Lo4r`wF-J`uzv*#z&EO2 zGy+_wZd}b>+}nuaUOF5ElHHFgrs4ZyK-?buj0NlHjyZpYBl`f(GJi_s{zFZ9BIMI# zWeRp?ANQi>Xu`}aNtjTVxs4d2b-oV1fUZuhw`IboqpQ^Se)_FdXQ+QHz)nASJD8D1 zNcH*%M&69`P&Un%7(l>i4Ru^Il?4&fv`cb1P{qAYA}dcP6!-;!ww5HWR9*o|WE=st;c5F%s|n3Mc{qJ$ItG-=1&R)%Jl1IsvRws!=tlaQsW@ z94q++a1N#jc%m7uMnG3%80UZOnjuVU2Xq1phrw%-8yW>Wp!xR0%h9_%SLKGqj-XgA!ua%v^M46vMYs(tBMv*4+x6+4Due2`YZI^Ie zqdZCv`AV?v@#8AuPHqi-r%rI63}8|;vG4=zKp8&2&!?boS#^&h{~8k9kYJs3K|-)A zJh8zFPj)&9TDV>#<#@}z8VL-*Yh+>D#Lh#Q=4b8Q*qVEn-( zYT!PA-i?-hDZT6rcTc$_0mgkG$^g!D5JkX?v&ew}iL<@CLpV4uWEC?x&|ME`aEL{2 z|LC-5H^ZKqe_Paw)fxrDXQ*&E=e)7cUFBRUzReiE?yGZjLH?C%NTfi-3~8ISf}w=* zJyZY;?+#b|pU0-^^@EqVl4>H&d=jfo%9NT}xLX_Ki@A{X&9a9z=``J6haXR1Hx5fq zMZho&HgsW=G<mg@mxZ-TMG76l8E_Z9z?n`-&v}Q8zxJT4F>v zLl^SG)s{L9hi1)>N9iLW`C;OB8rZ$B-KMCc>pzdt3Md9mjU>eoWy3N32`B^$ykurN z$0+@hqD-H8_+)4HZsUeV>5*L)^pgLNB)>5P@s*>Ol~?}h({9p|11De})s1JrcH!Xj zrB=I!#1k#t-K%d<4x1pny%87PWRFnv(B4FZ77=RV->-~^9Rc+f={ZYK$o5OAdu=_z z^zW>U(t2>75pZ-}c$#A1AbgTXT;(IV_*6kYf)hRQn}jLkf|8j5B4DRbHb266B=&|` zOXq*MH*z%7CJW)iv&dd%vs1{QC&o81bDl;UcRIxycz16{MRy;x0)2*N#0D|kR;2ZC zaIe9T)8*h9tQYy0Gn2ED{v6)w(TzS}Q^6AE-QK>o27SbErlGwK+?|5#>3=rHTD?1( zQoT2q)f`DMyfAk(CvYZKZFckEOV7g~(jfX9y-9A}3Ig^s4W830?>v1xSk1#&m&Id9 zHY!--ORt#nSxN6UO1}?y*sJJ=*L9t4EQ%)QWeA|_@XS4sKLBJ1ebH!{vy#3$vdtV6 zN?X3=gr=o2vVzTb-O+ISjC-3k;c^eERZa0T2$c^vKHQ;cy|4+Yy-n9}Ig;s+7W(~{ z(jfeMp^a{0=sSo|wBIeaH`8n_(XPYp)Rcs+Xf*il zJ#)hwd5m%atrsUX1p3PETPHOUT|{(uVBr5J{vhFU_bSfRIfT5YdzeNy!Zz}0X<`2B z5Q0vq`?eK0RHnT(F)idrzDG3wqMdy9KRNmvv3DF8d>lczscFbLQSN`|`lVJqH*Hfd zj0YVwPn<=A_t^X!1pG_KqU(lP@eV>QxB1ULNN|fr^Pjna2-$n6xr|oMl#PtG0K?J4 zpZ{I`|Hb~6dm=|ST@~^}h{fzla0GjpyzM|1KCUtdoNq$`4?#15pinVhv3lq%t!vPk zvLa>mot=_p@A4dkN5l@zvzCE?f{OJ1<06%vWC)*s2ZvC$ovBcGe<|{LByxIp)G4Gy z4fdbv@je=ZW==jR2=NVZ<9qH95qgGqav-z`P*pjv0ermv)1PW}{xnRNot=_&nm^J7 zE==WV`m9lSx;-={$uCL~U1Xna)AHPoGW|pf(Quz9p<{Qp)Q*p8MY%O>0=zsHd(h>h zfO&sNV3Pls*R+(^4^;#D9K_l8-#Cs}WpmMR@#!$esTbNlb^-<6um_nq z%VP=$O~vUTK1|Jr0npWQin;X_=zJFSG1ijTe8$MxTUK3ZE&?$pw{V{n!0vnJCs26Q z?{Z|78=c@1ACE%27CbEV?q0}TE;9yaOITVk5^)ykk38n&5rG#Sunww3Y&S@bv!aLs zubN=N_u*?}mwPVy3$NYO_F@7T5KX$(QfrP9Utj?hMO|(--Fbuxv4}w1o6@I^xL!byBrf(t7ocVmbxQyns%rJ#&y9fkQnnfBF?0QPf~6PMK1JhcFQ#tf0fQAKZ-^e z>K7XzE~0+oRMu=HxTX?`C+=j9J$)w*?xZezqq_U1ubbN*u0TH5qV5bRf- zKa4|(_!U9JJVjpk1>SV4FLilonlUVo(9mCB+w+*ZS@p#Nwnd2fRcno{kpC-N!!*Pn z2;M(@>bzh>>bSU&M8cD0@s(m8J4^qLmUkooZd1MTG`>aIVX!t#TZxVnVc~D}VmuG}(wRl0-iN zm@SkgJENld+|s)!+llFw*zAj6-Dk;G(_$tm+yMxlqBQB?Lprv3^D(Eqmhnx4xKS*F zp&Zh692=8lT@~-z@s}J1i8$~jPgkx=FN)1-?57$IH$V1EfPg=QSjhRnTuoXor6~v= z-1wMf z4Tv`UpaJAe>k{$Zk4U)trmT09%3fkcTNlnaB+TIx?C(ViI1jWUS_e`$Ndq!VA~EN& zl)dJ76299tR2#%)k9U5#B*C|-OxP%qU;;Yg@{lMoAzNL8$|7Kr{Ub2Dw?;*+>$!q2$0^thx#SOKz7VfVjllEvpb$5zq@@<;;KFb^vlc zcw1lHG|B$xyL9rWquGVu4;QIM6r|dsASd}ARRLIfJyq4Jnl>^Cnoj`6d`JZpgmnpZ z)=oj(!F_+LQ6}@6m6g@&kX~V3pZjrKNZcUdA4tRTp3~uouLOq>ZhHGi<1VX*+ z_kD$oaNrKpM{?;GYm~BinraQHE8edWCuHDDjC!2EFttgn4Tm1rr;}(C54_pi_L}3w z)-zhyr;lh)bP=S!XUB}y)#rhMriz17_A1Buh$M3^!+_Z z%CRP}6NKo-9^m@3JaVIkl^og(nMHQkUyiwud;e?BTPa!cwT#@i#Y~R{9HSh}hB!{T zGZaklbW21Kk6YG0>hjF}YROY?_$hRR`&e7f5fRo`n5l6s#rHe??OnamQ@p3)suLX) zWcQrft0BO9q|SuJ07M_iX^p(AEzMN0ROHN0O|2p;isc($TFwviuV2h_MOqHuPZz*>z1;Tpit1wGALN*9 zFZJ%0vu`^7U?2_E&et6kTn19F(UbZsip`;#nM2~YM;s>qrzovPIU?gcj zeiM({EHY))EG*FrHhWQURrN;tolsubsc_7~LS3K0=fi!cw;uJ$fA0Ue40g>^4x~~s zdIc)bsncHocL&Q1v_dHQo1-NV*cgA4`caf__x6%^g)^bIil=vms3Aol?!) z-K9+uw>taPeX$fegG>u86|l>7I&p74f@YWUjqfn1U97#8`h_dDo8N+W^x-2QY;vE8V0L{oqZ7>Qqp_Nb~sF88o{XYxdnUjafNaZ-IM8`IT) zJfE&KQ1dvRCgjU2r*1LhKVyP{k=e?R4@s@>Mr zxz%D6)V*9mN&v>}Xlyi>XL5WZANR9?%0jNAhWJ+WHpHSM?hUZ-k|7hBqm1@ z<$fJ`(ZX#qHUdTkdf3Mp1$#DB%j72L7IapaEgv4#D1tJT(w1D>QF!GOdwhtsmbBzC z7AAi{^5QQMT;s(25Q&9}@4H)DBLK`w>)XM$BW%8e350I%llFLdb6DQSJ&{PTK5?Kl zG7I;?d?xbCc+6zR6Y7y|g|yp@0|~PgjgU)1J(U;DqFC1s@JX$B71eP`+U?S9!6|Up zbwMb)wvsEo+BP3{D_6?B{EsL>YJS(yw1WknlK6=#rvY>D%?bbCU=Uu$c`+$*6NHDo zbDl~78oqCVOrf(VS^jv}Un4Yq7u0PY;7GkSI>Cz-Sb?~PxNPg0!*IR6*8Quz{q*lB zz@gD^nVB=M3{B$#2IU(^oc8$PJLDb#gV)@DUZq3t(yM=Im{^=OS#jz971wT4oD90m zD~rah$AG)U%&TZ~MM&$kTBP1<*O%xTc+(se7xe6hE1u%&x!tOy*gh%XFX1=Svy>I> z#STExWTHl~EgScQUr%a)+O)14zONsB3VB+4zULFv#ojGj(<%>yjl70*t^sKyoE58> z73{fEe}V%`@WWwT!4?GK+0YGwytj1Vbm|J@Zp_YX$9~r_vo`#M^x}=x!y(=E6cz~u z@Z;~0^>`OQ)hQwNFvZESh!YCDzG7u(?Z&{K%f!Gj#D?&L7u)rqeQC-lam@27hZ9;Z z{s{+JLm_Z)i8k>eWg!+I&_dUA9*B~N`sm=$dcQF2dm_a5JD(8iGjoN#t;YNB zavm1B7&%DM>(s8O1z=Av!CQ~-b@w*ZwkWZ@!h}CcJ@nnY*_<}#@ly1VR1^9xpm?>W zQV0F;Y`i4{*lD0u+&9z!I{-3Eh>3-ZGs)+jA=ov{C*`(D?URQEod;G+Yc~P84SW1_ zd{gx32Q)izwb+jVC8d)N>J{xq@;dN$%mofBxDtn1Myd&~GaG%6JRh*`DYMZ;3y#3Fqfx_PyUsP7wq4MLB?@e4oU3WPQHQIU-ApW9% z|3T|urOh`iSYWg(FH-RGW)VuL1%lw;$|dR9806#sKd#<0sHygg*ABh+B2_?&bdaj> zKtQ^H^dd+TktR)g4OKuzKtu$Dfb<%aUPDn7>Ae$>4npV&r0kRb`+hhxC*LxAGCR3< zW+nSx>-t@R6~%17OS3#{MQo&eZ-jT9BXRnKSj^A zT7e;&R#a^FTU(F9|N920xx5i)(^H9iJXEcGt0#AX`>5z^b?Yo<|IMP|-F08@RITm; zk}XTEdFSLB^)1}XRAy+2$y`?1&@>7)1ku|P=@H#HioJr~o0}Vh9RH#Rg(A!_zV5Sk zvSBKhqbI8yB*i%ia&V)t$M%}~#$?*BnGn^{4hPnsvY@MRSWtJzPy+lWWM@V*9|d!YwmA<%ol1a38~h2;(uql z`8ejaLYBAF-mJfUCwf`kw7$OcurTE_a16y5lJW?YDP|&p2f@-B$1_3r7@Z0V0?>c4v9s0 zQPJJP6U049_S3xO?|Em1N*(|54Q;f}KY7DHLjJ6V<>u%`qMJ2jC6>x(CHOmMQqu-_ zskiwOz74Dv%V`@s0E8${O?FNdg&I5DTn$bO<7)m>^RELVj+ahOmRb5~A*JW?y5A8 zN2+z;7)ftX)eFBFR#p}j&-qp7Q-lg^iffA0NUp!)q6=|;OX)+)Wv}^rAQbq%)#bgv z)P+*Q%5Gtdo*2lpGjqAnT?gL;SrV=PbRQtm3Mc5YmO17hbv99BZe83zSZX^};j~=f z-R*d`gcMfM#8kIEToZ5RC~D$>TBveR@qK*TerM%td`xK9gyeUlB`^?cFXUcXR-}@( zmvV6N=DvLO+UjYD{(|L(^1r#e=*5j)HkF9)zRIK#3A?wP^6l9&n;VYdS<@|JiOgo@ z`U&^b74MzUu9q&;UpoB`Ef~{2$#m5;3dA}2x)iIab zQEzvF$LjpEhtLl{r3b65Y?AMGgitz9=sy=#m@fs)wTU7+CY|;a293N#7({viPflGB|CQDRwS?3Uw;1E$`u2?0 zK=85BtBi6Y>yGw6m`z{sY3$#O!u+(@@M}N)ZfD{ZZuO<&=a4B$1Ep5*jUgU`u}d$ljk(GZm;t@Akl2Ae2tQ@F=kPh(Yee?>U-+f+bExR?oA|e6v9>eE z!1j5A9syo9vHCqougyh6qOYK#e%$zy*!xeIFQVs3-hbcQWIH;bE9-V?ArJjZVlQND#>UzdGjKR<*MAYYt0Ts34e>zIw$%Qy_8cz_rQ6c+VS^7np*_|AA}uQQ9w0va;mut& zihpicijB&Ts-D*KUI5X`_x8wXLQyd$(N*HOt+1&LS}dz-l>u;SzGo#sX6%dBotP3# zFs~xEgW-8VI0BPMZ(myDdGD>AfvN^Kbh-w3k`7K@F1l0W*BSvd8YoFH$)mh)=^~cK zN&{!AS^><`>>CI@<#QOw6ri((Pq35*8L3N4Cn(fx8&%m0YSf0Y-f$VrgG*RXD}5 zUzRT9J0Pjc}5ISW^EP{@(&wVFu3*>ftI+x(u1(D&$ai+s`g z_MdO#+}X$xQ}SLl4;r)}09VqIaxgeP=lQaS*1yKwW!`$Y;}+%NYc$8D)6?f2up}d5 zxqNF~co;o%N01S81ub_}B?G-LL4%v$n)^%a{F;qfetpbf;@adWjbG{0c+?bb@U32q zBMy7Ymf_ptGmUQr3@DovG+HGrVh4HXI==9%^V^l$zwisW#u}Y1(Y4$BVq1d67NH^4 zDpj(BsJ1StO2rSv)|eH|&hP7F^JWFHJ~-9IpH9RKmcu^i+j93hpN_BgN|Z>(tSY0r zN^Sy`W*dttznA>NuBF00UgJ!F#-=@ro!3?ZA;1=~p+}BA-#!;zRwvS-H4a<2>Y1FW zo5rSidfjaN&?Xu(H_`17^>$l5yX!z{*Dd1Pk@IqHUY*lh>Yq*Ax+6xBx>=r%f5jOY zj73A@M7FQ(xtk)&B~RF<3(B17o4tqN9b=qsS@4@HjjAmV$l?W_Sl|l~?Kp2ZK8g>x zzbT$&wC2D*iLD%U&2f>Mxr5*}(fxdHaHEyKEjrNC z44|v1zOxi~v&V0GNS(S85xMT^$El{y^9oCddtn3)R(4--Ok6D%&@uifX^c%Cd~r!# z!*VKf3ayCmKvlPq?d1dnTJ(mv&ylPIHmM)&eIr(5pyn`3iLj1XfoKqsEt@}~Gkyi! zr+N{gU$Qj#D!LT{{}731PCxo;ct4x;P5Q@i(kGDMy9M$0;qPSDO-VphAR!lLv$%~+ zP!JU(#cNBBFgMLa_6&Z~8<6#H#Usau?tvg)DG#aKRXi*lq$|G5kvxe``tkqi4*#$F zFwiwIhKnptnt{ZepY$a0^#o~^?)qZ}eX7BBkaZqyveN7vQD9%p7m5C&%6kk!y8v)< z7k-t8!o#A=5wJB&{mnUo*khpI_jK%M1(0>b?`Y=9D96O5w>b>|OUHY+x#)wa*vkV< zyO{`Hbr|Rio%KdL54vDVr#KEAuy6Qh?$H@qgm<1bKOju8coh2tj;6e|-=NQfoF z&eN}cK(us|O*yLj*_h#`^}Mri{NKeMq?1Xb1aV+pN_YO^xq&?VXPyhmf>E~(vLL4}_nx8@=7w6Ydn(#1VWgugn`z)sub353Kihm=ui)~&5Z}`T0A=en=Qj@jQd94w5yqFUX1uA2j7-}sJwLIx z85BqcAqOtR_}6S7w$U(Q{A$V9H1=A1yQcy#I(i z?bnvVx08;2cgxGP-h{oTI-0Q_`>_ba4siD!7Mkkc@?=SwXvpts`FaDqd`A9g6+^n~ zk8fhX{<`L8zCTh>7WpB<^NpZGl1YjMtqt-QzynlJG~YwT^(^BcH<2fd2}4(V#X$me zY4X&aDSlcF))}TF(}folXCi|9?UwxfW5O$vb}p3pq)c7+A{3= zB*!i%*xL_T1X3gWbo0poSeLGM;{|{{Ns6U69~_sS$j|RW%=d(z^NR}y_Zg2mctd`* z$FCuHPaOD=$IqF6xUBg+yLq8|pR+th`)i(` ze-In6JjoNihS7vp!=LhUyk3iAysQ|%&y*7d6|(B<<~})loaDXE=GlN-bMs)xeIpIX zY>R(oMS*X;d-}T=cy$1+_`R57>;o>RalrhGfrkL{6|=Wm16e3$_3A@Jg5HHW&mk`k zWwb4WH;HuOdr*{Z##H>VvSon#E`$(758=R462<=JXCPq$1%PrX64{?q8jVmiy0Tzw zu*3-tvBtdJ$XO)1FXd*9w1Ef7P7z{jBrT# z4x7ZGaMpJaI$7m(2at{cM1~62r37aaWnWPgLzkYU$u6`xL-9#46}x zwK+a4s<7bAhJbt;)3y**mtyB0ZL2_wW#Z72?sptM1PRcmsQ;*5cGNkdw=unbh0D<^ zsa-UAX0*$+N6%KeSy4fcFZs5Kh%-yUwYMv!f)8<;VEMD7r2%(~=k1VHF9p-PNss$A zb08Ru4aR*Ej$ikK0kmVPH~<0}x|&cZeU`>QZ5$WQnDlKv{5=tk9h=Ws$JK6s4GJk& z*wR31wYEcb?}51&$WhwK|3RmdP!L5RX@rxf`cvUcrp&)nJR%_?0v^zp|ZHfD=xqV8$dS>QM1sjci%S6NKc+>sM@HvdcXh0`rfLJmRv zP6&O`NP@PQ$ofO&2dCm2p~H+LeA7{23gunM3~C%vhW&2dfTL%aKFADM`jOr|cSzFk z^Td^W`l(}^>$97+R*l;(`=yl#qWw?L?XV)e0=rKFH!hcbrf16MHWJwg3Te8biYOm2 z9`A8I(fUk6a$9C4Co@8A$_H5XBSd1vtoPHJii&D)8ky^mkQOZe+bq*Cm3v0C^I99( zI>+G@w10@&w&|kJ*p^_QB@z6rey$>|Caq^c;hx9T6aA0M#{6_hWb=wYI#S3!hF|&q zo01s}$)uWPhcHUTEI@}r4PWCs;2|!oV($rtz!1T9NV|aR>3N6YO1w`b1b+B_`s)0f zDc@Nt*PZEzSLH>N;4M3`(Jw^#!H&4~smuP3&fW79`0GmJ`)a$gkTm4f0i1#ay18Rs z?*7h(_8=|V=3WbsZ)eM0uFE!z77b|vtuIP0Lo4yH9WX0Fa6 zH>b16>x7#%`qxqp>G`b}@oNtZ*>3Gp9M5S6-dm{bI8thrAq1w?dp;3Y(9@RW(CPC{ zgz%QeFn0Pxl7DrcQC7=p48QNv!EV9QWKb?|cAekL5GG5~_m_SU zJ8Qj5?%KYHKU7Df~dJ1`lDbZ&~)^m8-$1*iT6rRj}2RDYyIjP74)qz2 zNm7=mXdgOecw5N|t7;>fkhFazG8jMdbXIkqU(#?``_r(k%APPMPCqEW=&-PBQdB&Q zWa9Rr(^L3SF*rM#Xj#GpUTZzJ5B=h|g?SYJPgIg4l}od^(V^U(Sfw#lmNb2`tMJd; z4Vi7UL5i&6SYGlw;kd`2LTVa(kM4BOeQT`oo)AkC5T(+sMY7JAg$o>A3N~7l6YaZ{ zrkY*usSIZ&d10ItLS;I&o4abIEJFdDbo_7i{t7yW$54Y9>b>bD)^>wd1Ogx2mho`a zu^pca#Y1k&zTTfi;4TONg@^QjCp9P6oi)3xBFfILzT72$cfrNC6Zj5% znOLQA+14IU`xiY7$E>8Qeg3P6tM4NsiXzlH+EsV*3hb@$o{j_ur+J|K{&8%7-_}jh z2j*?k&DSZ+F$(WAtK20AlOJNDt;PPB3dMH4+artAjgKIMPU3?N$+TzPk9&c2gF4rv z-DXId1cT1h%b`Kiyq6%pGwb=f_?ot!(5I6xdElEV?vsvnpExBc8ehagO4*;50cP?+ zisO+CkL|uxo5`|w^VF)mzK_HnC0qtYYZ(98!R+lLsqmL@)UNaE|J9zQa>X<<4OZ{h zO9{6-cfHp|q#LVvUX{gXX}3r4n!i8Xy`@t0^?mXYcot>%jkQkh7d<{v%BuoZln7yFn{xN z<^w=5kWI^z#%ul66xHyE<8Ak)H`34hB zCw1j%>&S$Pe3Bk9dk-v@`0;iV45|Ae*l~Wa*@ris)l<&o0_gcuru6p*renN%5vH~w zm-G(wwC4uCMiV}8vECeAn9QM(?7nFnw)h9E^-fHE`I*o^@w<^p`37L%(r`52Z#MyT z^^fjZLe(Mg`~82(Q&+g_-{P%LR{gMrbxDj4pG6cjrk$giddev7w7(hYohsz?2NRp5 z&E)4UCW6xeHAc-V>1(*n*h*WF<8#mi7b$G!`e<>)=Iayxf%r4OVV&rlNqdo@H$c4)y!bN^>IB2dd7G7z3?UQ9 zKNo#|Dz2b$rC&7l{a#g9CDE;;K&9aqv9a@2o2?XB5t*LjZ;6|yjC_{?81a9(lZePz za|)MGW14?=aeAD6_{n5i0eMS;K3!F#dzmfcKQbL_^|yX;Z>VbY zytmU~T{>gnLd5k{e8)k-TFPyI*s|KavJS`HqUR|;@LE@p+}t`nZf|b6361c|e)! z^Mgxc^Kb5k@j0>YA?Z1W^``4R#6_$cS!XA?C6^a(P|MRXYalldm> zfr`J@x|X%QBCkU}vEB0JX;`!}f$m=aq8+vFOSuHb5r=YeD3l|`Sa8 z?$y%07>%s+qkSI8p?>mk1bzOa`;X`kL-RsmqTh?s8kmo#1ij@K!-d&Sb~Hk|@>5|e zRvwi{2Q-kv?A=UnL)5QfgU-uj($2qwiPsvwRZr`QJKBTAqJj39-;-yD|D%dRntrpf z(*F}jY>y$cwZ>X}F^Fp??!XWLf6ru#{44Y_q1*584e$KMyeE)It?(!7OA%~_=w|e9 zvgVP~}5AyK>-KuC$`(8T0>tgxT4-pLJztq97$GF5DLU`8tH;0+`{{01ZqErxn(WjJ~x7w25b zr}crX{o;2jBWA)}0t+g(4`cW)VFR0P@%PpPSYsclyy_%+W|iRMCFN=yr>Pj47R3^o z`CseXjd$9&ch3d~V*t@Lv2Kx;!y>L%H7h5|uov?z2ZSr4g$aT_a(86A_2r#2;(k;( z#H=UXzEBz9O9ui>V&1N(!giaE#4&^HG3Q-qaMr4>-qzro@93lx@!Z3e#<#w8k}_0p z{%SM2=6pMlx0?{Jc=8E>@0I|Q5oNm#`2nR& zyMo~lp~%=n9=yE?vJ9#;#Au}^_3fc174cUMo4HCb^2g%U{YU^=H1vya;s!7US`Xjj zBXAM;u;|6lI%!NqaGYZm{BiRbudJJdeTL=zFys&jUOx-jdkwoe5U;Nqz2CNLNS3Vc z)hOxHQN2PYN^j~)Rz)#iPR7@af2B&=K#6l}hQM!_41AwdKzg^{d7llJUB@E$`l`y# z3%kUH6!$o5AgE9r@SsR}V(3_Wn1aUp$zQ$+L_jFUdT;Q%IDsr64D=Lt(te&J*KixaW3 z;7fvCrMqz&0zETNla~JX+(}AbiCcUb|B>{}aCsAbKfhW2^qADD1Z&52{h|pbB_FDB zDKzs1eC>LxZGroAYwA`{A2`4jT*aM&pREx+oK`&4kvKB;+Ti#cE2})u#78RY0$Gar znDImA_iXG#z?VrGCf%=1>PB6F`xHC>lPYp2FoXm(i&CTw%KWH?IDFt05;_WaEB6uj zd_~#}Pa-v)HXKAFcwcGh@w&P~=3y=hc)u*1R@aH<7{3bvM=-!kdx2|o=)IwHE6DLK z$8bPg4-@HOWwsiCi`^1aA>5xOBu8hNL|p1+|JwJCgoVF$PhdrCaNw+i5eqa7? z_2{de^Kqw!C0TC2?q7RL_YfZ#dF!#p6)$&PoEHR^U~LayED~1~8Sf686P(2hgEB>t z);%{FesjkR=x&MR$@w-Ksh@;aSPKc8vK}G9hzx*WRlgK;&zC52@;L4KTWrp>p3yMv_ ze+z5QwpOoB9V+RJM`Ic@#Qh0Y2TGG8q?IMWS^G(M2d>hSW6ZvLBDrb&d+N(fzQzf} zQlOYbKU3|UUZHBosr17l5I-^t8NDujW94{t_39vdHTV}%xs8E7gU~b!z&C>_xW&5nqdpySIW^e=2CvTMz4b%U|#p|WVk~o z&R0zp^+7)T&C(Az$x#Stv@gTee9GfbNYdM#oLnDn&*KyB@=TE0&%H&LB(D&wT=SmI(J6MQ~%v16#?^Akxsq4zD@Km5;E=j7`i*dsKrxaw>9b;DD zXV-n!O8Q%Zx* zA$BM7mfMgdy0G8wg5huW7>k6zi;FiAv?3H?qA7o`<#P;okxBpG`>5g1F`Z38WF6;I zIY1r7JNMJaigvH7aCt{XUE(611{D56y5b8dPr3_f zP-wz}kz5;uLzGhlDR}dbIT+qC!Q;65!P^%nVBWGEUih7t0>PyM+ddMe#D+ioMhMu%{P{y`bK?G)v%}E? z8?Q;{p!Z-v9FBy)uTa)of{NW~lFVP-IHsYen*SAzzrrY#&cfgl_za!&+J*B)$4cok>&O}Usg2W}y|NopM zwK$81ht>(P;7Lg*Up}pM$|5@~O1wnw4!PW}=K1sIy0mU%vWd?`Kx_^a*G>Df>Mv6= z<2eIM?Y3on+k_@FsQAhk9~tm;o2km*oHO=5+i5`9e1Ao$u&9VfMkJ+}f8j%Bs|z#R zBC(Gz$D`N@=pZv%3%35KUkxo#H1nhw-H_a1cZZa-Ng35oto4Q8As@uVX)kn#6${cr zW;Sx){=x4~?6id^cqnNT_!&nY%~o{_$Xmls*?$xOk=)G z2u!ZGIx=1kxUY9&&SKCyqz{NRbJv`@1RMaT^#r;?zF&R(0aJ`1>UaUrU0b3r-vu6M z9Zxv3zfczSS?ayRK1qDUNv$7u73a4I2{qMbSw+ZT5?$bAM+z=fv#}y3LNDuMK-#-M z(qLLSaY!Vh?TA6#l zT)B?FioZ%Uon8;Cw5XK8fMakj7yi-ZD0A!WFu<>cpw2>)g3JaV$`iXeLMelQLeB9% z$avy|g~x2R3wk=Ie7qV}muQ>n6LPf9E9=q&rsFUovKpVek*E!jXuO=1rw!rdKE`3uCuYi_b}Qa z5mYr?ch9K+;(_?w+_XhDTOi*nLLU zx9k2={hFL%JR;0aByk*7q46NmL{Le(%jY9hPx$$X!6}CxHTDU-_Z~e>9Oyp(!~bY$ zdW&%I_&fhddw!GFwE8vO^#dP|c3qp*02 zSO*0+^|T^T|NDRhJXHXlo)jb9^-%gmYk+@{sG2M>*PmBI`098b3u@+=%$9=8gAWJd z9o9w6GUrMyfL>6KJ+51P~?xZ#v^*b!Ij5nyE%UR_@UvLjZq==(wkzUZZ! z`RH4NJOTBjysoawpD)FUGWJnM(LH5MvHLL$Ce54id>jX&VH<`SUT>gH{;IQ~y%|u5 z>I=%00&lx=@)2e>>vE&QDrZgN_(Qv*hvbZD%uREQXl(*eDmeZNnu503$QqWyJf-9} z)%cv6HPz2{c_+Be)BhP6si_&}DPvp`;}C(kWSVQds}W1~K7{7{dW{8bvz}Upal#RA zd_1%}B*m>FTTCq5xcp5S*Cu%SikyW=ZjYRExmv(fcrT|d`V z3TAr~cH^V4HvLet9K>f9|J&0?Bp^}3)p>6Mzd@45sTrps3eE>dPb&IeGYw5qE6tIR zG7{;~8q^c$QxI!lpKlDF6|Lxf;nd$pd&^Yxg%eHUL$DL!#OZ`D3266612|I9uJ6)>i2)Glgrl=5^Z9b zeiIIw03^O4+~b2$d^R!h1sha9Xjyw7Up_%xM+p#9>J5mjV!mZoLC^vu)B+Yo8pOX; z$#Mw_wchyPy*uO?6H>9BV&A?*$AA=cb)UOE#$oz-nAi|1F)4%*_v7Xx;Xd-Ie#N(5 zjJ=k-36fYi`CWpqB;Nn-ZT=Suw_^1=1{yLTD^I5PxR*;LJMe~P?%C;3JooA>ial#K zN(XA+jzM_xjZ)O%fi;TjdI16oA7u9ZQXYUb$lkdYto>x{WNBRoJ=Z8cOfNlj>MR9I zVJ`Us{+>b+NVjgmW_REF>!7TjV=R#2NHHAj8{0DnPePnPS-iF;>n$v8prgI*SaZjD*?Jpo#IXGv7ANjj_nGTHt z_|YEY%gc%)Ve;NDO6HnfMci$C`V5!2gTMxVWs~bdm#dT1{0C~f19xOR9N6d$jCXBV z=4Werr=CUnTaYwTP{?vz(&R;6P}UEkM_TpmW2rm>%%7AKmF#`ET>tiUF}RQbASwKQ zC0iKuj}^-lci%&VUPU+P!S$`jrO)9l63&xC!3~i1!&aqRd(0bu=!&Fv4~f4ixt_@q zc$H<5@ZFt!_}ue3Uhn5bhXsz!aaN3}{@RN+-XKM?$q-mAQ)%jG$1#1_Z&fHm`BMXD z^Jt}l_`(mQMUNTHii(*3tML>-N0zOu5h-;G<~SsfYfq;ta$yiIJTR?p(OSw=-i%Eo zc*n|C(k1s|m+PO8P2@|bn{Pif?~ytcg0c|VZ5%k>{n^`r=e?xQW`E*OI3J0XjLSxO z<7fW)RTJAYU3T9T#DkN2KPxbkpAAKT3g-LPLLK@dyeRYMPdYWbC3a7>lKXvGzKh{3 z4g+3dWultDo#i&?dnF&8qc^3KT_WPW#}Sj5x28l9ZuvDdtZ9m#H6+uM+(7^?sG%GP z2cus8V?O>#r2oIAyF)(!kH|GQ85sTMFG_x1qhE2Xc(yD>v zhq6b16a75!>f+hJs}3SGRjq(Qyc!)ob|l6L5eq!aq*EIjom;24nzU2H@w?~U+B0S^ zs^se2v%Ty#yPwNcw|3rOQA;@CPcCYj*W`&lv>Z-#fHi51PFluo@@vqvgA|URXDVH} zXZk8~Ld3u99$*%O-r2)=ICX)RxC;L@m{T0DQ<5A0@;5*2RtR1PnMb~|!i`aWc|=F~ zrv^)2&0m^*;eJ$`J(1!u&+nM56&CxH+rSX;{J&KRx8MQ*gjt{ku>M1q)%S=f5U1UNF5+w10 zcv_RA?=#|b8Vj{HcbD{lqY2>?1RVR+t0cK(+?R6~=Y(Y#`;5ftw`k#1-;^8>Qo=NM z6qz*dPxILH`r^c-t1;U(vYG2qRd1L96_D4SMQTo zTz1tVGu>Mq{$^C~=gl+h$G>(c)xSB;^1;!!-D~C^j1Y2St&6wQoRNv*`(gd66~Y+ipS(}opCe|!OfkEx%x(Ep(eg#k=260n7B4>?m9)fKc7Oc~*q zRywY{l(*|&#W@xFGbt{kZM^<^y`x(dgp6;JSeo;N8}$%6R$_9`J`e)_t!q3%@ArphZB{E!T@bjY!Fcs=wVk=>)0#sh&PR zf_nal1Z8o{R57?(EK$U)v#IKnr^N2m|9q%5a!i=bflS%(R*~dI&dgu+o-4NtLXZOH zYS)`k%(tKbyaKRR|5~6@JlaEo^Tq}OJ40tyU-?yNNyBviZS&&_jp~`N=<445RyV0I zll0GM>nqoKU#CLW_)cBoi?1FdtoPvMM;B*ky+cJWzv2x}Zt{r)ZF`}KL%biFaZX=+ zu1OH9c~SM&+XeZ%!kwOqb?LzvNbixoR2^_1RK7YV2_A0hm#xzG9h&dq6_@l7>fsF+ffA}%2lXM2q~?={k%N60thKqXJ3@R8CfqAJAe;zB0M7dwqT z6%c$Av&H-aLn1>oI`iWD1Y7;81 zy?>wYoDKXpH!oMXIhhYRSmFtbt!3#JM#Fu=`BWeWgyz%+KH$y`nee;04`zE;vXQ|! zrOu0ry0WR=9AN1cu|J#MwhrrWO1?eB4anZA-a84Ibb9l@$7Ss2*xWC12PfXcH(0Um z)m4Hd_Pz4@Yh)TX;os0tCh<8cm0DNyz5iHu-roMbm^DqNLCn?sfWiRi^V;N2p+90b z+t6zl-Src>gNl`w=?ViHl-$aP@GVQypD{yVHChFC5cmUf&bC81 zd`rGFk`H(^n-zKv9xviF8A)%iI$j)u;$zZ+SD0mhq8A_{qWfVWQqa2ObEc*Z(_#`; zB3GR#AK7S=jya;evSiV+*p}?mra5zD)sJ$G3&YaxYaqU9wtLLCrtBl}B~t3()+nO! zoYHiUcCrUMGat921NUE22!(+R^OZj|JHd}?6swT`sZje?zCGQkjWOIsBLC?4O`h!H zu{mrDbF)fuTOgVf-xY+olsbw(D<;5J}?!FS^)9W(P z4DuC za$xIXHGo3Wv$$QFp>yXOv}En`yOKHmC=*=bQB;d9)d^A*W^i{sXN!Uf=|`8 z#r=AP#44U)OO~sF;9>Qj>uw)z4OzBF7^;_voA*)Xd|IkQShxp#nFWb+2I?sPmwerD-@-RCl zx)d<9zdQ;N?fn>MzLXQ}Lj;En@lB@#p&k2eelS;{kZ=0ogvX}7if}_)<_sB-#>S9C z6}B0f@`k$#=gB^BF>;f0#ngf_?1#f`vnhh?sj?5Ya;%R}@UP6pSNDuUXg41b_}i}2c(^TiZU^|AdTfDm z{r8z6>$vaHtrRY^uX5?+q^*^|tC`}#an)Y(J@N4-ZM{E8Z#o&e>kaj_%aoEYJ02+0 zO~1!V5!>D#yj@M%jl2*;1r5IGmE}_FpO5?WF3S*nCO|WUA#<$XsGzFI9|Qusb%L1&f+DhPH0?`rcK_=;16iJA9>EQr96WTye^Si zjvd)Vfnd!%M55i}A%Xcz@_IWT)0OLf+#~ERyLGm=roX!&hfl;5Y(IKpa(7HXs+|`` zh$Li46002LQ$qi#FXfy29NRX@+Q%}dQ^1E)X7=DUw+UYrDb9EK(R6IkHhaDEmOdi1 zq8`Hal`U}WJ*$h@Cnrbd9}=of(sk%N_RB`im}ByeGdXR{ja`YmgG$DxQ}qdvGK+mA)PRee3431t`Q;K|cR@D=57Emv~uh#}ntTx;+ea zLUS$U*v+Nf_8ma&>Z*E5>P1^fw?4q9syMtC-93F4299X#Fi&(n_)@`5}qWMQUZ!UnzL%Wk*-MCUw+9T7qmXm^_enpJbl#;3usLzR2Zo8dh!Ay{H3VU@<;Ex@=ED)H zknp>0U>Tr0m{PYEob-DXYMa(YwZadc`;vTO<-J(8z+JSocgg~$@Ic&(@x^{>*b=4x>p5>A&3B^v~VMi9L!Gng}9hQ@k?4 zfqc|TvYr0XsFgOWjN}0{G^R@?8wSpQqbsTA^f^)-pC8Y_l@8TKBUKFG4BErckZnwUitb zDA&CtIndf#wnvjO66#-bD3p171Q=%NfN-V^^0doG!x=3<{qQ4?kDz6-I{cTmA3}-z z>}p_FXn~RzZ(r07apv>KR}MFw4}TO@?PJWO&M^Rr#2yZFx;d-(6c(bmY5-26~lFXseT`pz|cjeA}+ho>z(O|nxhNA2HiT52RQM3Ogl>oq$MCf&54K!6Mmu(tnFT(10S_4_4 zDC~L6)ARARXW{ev#yk(z_(OeoG9B=BOU@B^lX)4YvGN(N@tV+uUIx?{#UROpl42nW zm1`J9?90R8)t9*e<+6;N-^+>b!j_{;NfQ}d%z~sc?;HWhtGpL&0OayF_~(Fm2DD6X ztyv!??S4SGNqkh`y9%E6jtS53dy}2=8^Oxt=0|_8j$AEj65m&h2!n1<2z^-GD0F_l zb9t*3KI`nU)V1~=sMdaEBKcKa3#`X``G2@N>!>K+fNj&=At_yg2nYyDEFB`DA|X;D z(j}csBO)b8N`o{D(z%4v-Mt_PNG(V#Y|XpB_xsNI&iVeEvuDrjv$Jz%o@e&HuS;C$ zspWkvKoI)tY9VN`!}zfCGUp;yiaVr;b}%Ce<#NyjExB5JofcXRzCB}=#ABke!J{L@ zONbG^`La&qZB3TenT!7V4*?WdLOsc4&mU|Wek{s1`ITgVU_P0CChASZr%S!QQjA?& z10=(`XTv%y2?8}N^+Hc=wf(q0%=cvBM@2mkMPAR4?%ZusmXv`M)CYf+T_}7~ z+~{8m^q0<_bgmBR;7}SXfpgRV{+NZU6_g<< ziiO=#HNDn0O=IWnhUaoj%k zfDIhdegiyf8{;VGB30ARru=Z=Xnxy3RbG3$5WJ5rFyZ~lHYziw`QouEhYhWk*U|Wm z-F|zio!AV9&AY^xFmKMLf*kDAHFRetAv<$Azbt`?li7Ws?v7&k?SpY1cMx0G2852P zdNcusu?58lc`&b2YJMG3zy{?yq`n3LV$VL?_Lpe~bh>*S$%vDjUm62gtmxP2+r^VL z+FismUb6ba!kdQg(uI~S`PMTEH@Q+S!O3Do%C^dXZpGLRbE87opc!L-h5U`d*{T(0 z=7vjT;bdO<(fWfqvFpyJI!~jMB=YCBG^MN=S91uYPen!GBIj?PxVSULt3DU`)4dK| zYS)T+_;q6u`WhaG(e?=D5IptDIsKgR$>`c|ssGJ&PV+3Zkq}f>2^@>z>cXZkrjXz|;G_Ipf?a3S zS8rS@Sj6)S8X>#m`up=n)4Mn88~;N8xqLfw^>TN?O6*W>*$~7M;Yr#SW{BW_!n3_s z7%nXJ8(+Gwk3@PgyeL*plB-*i)QISzLpWVbE}oap6&Darv-MT?*UN;fV1PB+nzdz6 z{5-++E$si_Sl*_+q)1}^PDfz(g=jdh({zg>CNJWd;+p9Vt4bjo83i6uaX9`1#uhq8 zx@7Wvvv7XJ`vFp$bYzhM#*D-Qy6&gxA-;zNz!8hq3C2OG<>i$O_H0(H1?>P>uLAfw z21{kwhy-?rrPl;EA0_tmwvLU+ou(l*Jg+J7BOt33ePjaH1952A|E$F7*OoC+@Ee?@Z9b z7M+&hEXuuuw^Dz^0k*jHy;Sc$rG-~(wx6*EraIEZwlaPd{Bay=S?%OYp;FoT zwwJO{xEj1)l)Ln3jKB^4#j0S5?d1}c<(pfv`ug6it%`46Hr@CYdf1nTEoJ+;OPHIY_z;6G-EdoY28f$X&rx6>E7z@@!_0k`x zat_#}2+j8Wek?)Qa_l@7eG<-kk)YQEkgNpjW8wI7X|L;GKv$cWj&;!Lu}maRphdOm z`1xVCvPU@(B}H1I2`s?{VK^rND4oNA@1S<`#c^Ax*_Y zXw)?-w?zPHs{yUQ%h_JC3aj(=Q~QJ~qX9M)HP9)@IItK-$GLbPIH9 z)`jii^TfYjD~$vGrtu32^xTM&!0l%buccdV480O-DOF!iI<5t)Z$g=Yf}#p-=2AK9 zS@MulIwTxXaPAsh;t}FFZt_no_{Mhj?Jm!p_1E$KQ7qw|KShW|1AoaJ+DnqPIru+W zdz!Qs^~?5Wv~PUCZ}(h+eBIprw-2$@`5DtqoP;kNN?8w+cdNh5btvo{^}nveLoV#p z`i386xPZ!1Wf#@mT3Y(mI7J0?#oV<8-cZw{G_U;<4kW;X%fBZdKoc^I>(VWpOPxsn z1qX}A6@DJjpndnrh}R<0eaciFj8Ewtb#3U&Qn6(VftbbA-@f~CaaM?#0iunfwQf0k zG(>w1^9|9RMuM&YswcPLzgi#lNQj4rVp>#PPlgvC!hButU!N1g>c$^PZ@8w68QHuZ zihByGLMjtmR945>YHo2iNb>c__2PB=zJW~{F4u?C6t*Yp4KeOcjGH6k!I?ORKuNPN<*;m8R1 z>ndzfV?tkFzF4$hR5I{%Z^?3C6H-bMRr=8UQF^YZMAvf1$8jNci1E*q<|%Sy4uI z9>x6bUT2;S(3|{b-c^FDbLDGPNVz4H&NT~2sEOr#{Rs$temuT2)%TkZ)01+g`S+w& z`KouOPCx}Nc6{&fdBW;tU94X|0N12-qJ1#V(_Dt2F!Gp3pP zwGn0qNovWhf73PyoDI#;yUVriW}0j>?)({VSv)?Y$nhn&brtKi7!~+UEq=9ISd9gI ztN7Xn31m;zV=L~rd?Bx|NRhsPxxo;RU}xN)uqcIY!!86CipJ<}1p>l{oVJ&Awq`S| z?z#qqHnmCVWOIRT-^gJ4+RJ3dsX(rnB*?|sVN+Y*j-%tqy7i)Yhlr2J!8<1f&nUN+-#k zzx&Skms*G?oafruDEn=Rgon}700zL5c`qVESKryBc3}Oulhp?gzM@O!j5TA zzVRG?L>7czIiEW33BBVZ)Q6B3MDc?5eS4~ZErI+BTJt`I zKT|9yCFh@3@ti)G<3Eh&*jRN#c=11xs{6S?d9(9lU5!_B0w0BnmV5!1&e?JM=+8si z&rZlnQraa&{Sps#)}A$8Ow<`Sy&=?#_b|=aT2Jnp^aR`=Y7=0!!oszEky;Y><7~jQ zB2D=)wIs~H4XC&-R=C|*x?-Ahf}<{gckhI|Dvkimj-41v1#1|Z{)JfUGdtf#&h*Tk zGgiwmWi?wB+R*3H4s)6vc^Ym|`!LZcPR@TPGz;=~qC&SEzZZEEIvhZqbiA{8+rUT4 z(_Dz*YR@SLjEe5VEOY@@i^bC%?082(*r1{OYM!C`RmKtWa21zEisx0T%``>5r(p6m zDY(FHlxMhHk5s*2d=c7(GaagYdU@Y*hCP|}86Mz#J7W;x7WLS*gUWv-M86w4sB-pc zm6S2`fproV({%HTu$h#>N&S12b^a+A;7*A@@@EzLiReKwKm0%>w)JZ6)B*16jT!l! zr7^@%5;0HlqI$#AFH>z*QgBD~_G;`0|=-LD&;+``~r8hs7 z7Qs(oL*>o^PPLp2mO8!4-#7ofM76e!(>8{G{x<6^j3G)lc?fBKPeWAgjj)az9o)T- z2NRrgu<8h@|8{;9EU+RQ*HSj>ppJ;yxYf;eDb=$p&I^fph zPtvH{&=x0Z#fxSIk2cy3%24Xzy>3=deV!K`!9d5(AR|R8R+~2_XV+} zm_NL=V138(a(tKxY#5$q&Tqk>r}DdDU^RKRxTa*#9CeJV3(?e9q(1glc3{)DfbQmi z7cOz_*-x;w6UZWL@-ioIZzWf0mG*@9vyO&ibfZ?UBmRW1)6<3+Nws!;l9kXyZEnyX z*=7gujYJE{!QHa=q@QbDRwCAH@lB^&3y8k{f;Bn$hpp?W47%%qI*6<8k2?v&3JpFz zRYL&7&n?(BVFFjf-TlSh_@A9iz#d3(l>v>~OV`$U8IOR+AdhmK{mbHBdWj}qd$3W( zSlw({mkB4CT+%MNo^5H$v!f6GLTyybDo`5+Pt|wM0`=Elo<9HV`|GN&&I)^drX;?> zc{(pJHYQ+Ft4;Svr6OI@iRd>Pz{0KD>^jTok-MQ-mn!J7pUOLAJIm5Z}a6BXZ8Iu)l=?R8NFF$Cqk0j zo2oDcBCa(!9~cOKWnulP^_WJrv_jh(uji)FA_dw)THJq7(L zylh3;m)bd2++_OkUq1=2wX4MTbYG)?XzDHgzn&{oYGgwhwRJJNuL!DyaDl ziDc^JzA`9kbU1f@MG_g4+YdgmtLqL9lh~e$F(i6iJSg$jL{)05f`GAYgngN(G5v}9 z$(EI~Wt`o9J@zlT{+*(a9pPD|B6vhZ>b-b8gm~NvIAV?! zS+OHCnBXGivN7f@pu%U$r+GV~D2f$VGx^8g(_ieXagL7H)de4ot!sRoLW4|%qkEsR z?|wIQ;ef!ie6kB3_`#)TfI>`lYddHW!vV({y2s%r3-S%cBp6Ho+e$kxd~SsyB-&7t#Lm}|@g71_lqZL$W zuUkR@`RY$8iUpfr$(}xPx}~HzdD-m+&K6IsYgKIOxi^v*qfuu5LMH4sOfC3zqO%eh z`RnPe9QYCG<#8=FT?F%hC=b(y#x`*Du#lZD$a|SY+Y09@8!y+(w-6 zZ+E?g5^PGzFd0t2e3(t2B+a1?9-p#A*;qj0-GR{X-^xaR@1v5Z^jCAQQ;c<97TRwi zjjM92p=-pCND(&Dwa(>;9Ux&5`2`IH_iSAFo{a6_RDe$Lr3_5>b>>tzn=s*vDIPi= zi6$2B9jjQL*#@P9fMdfdcuLzpiX6E1EK?Lcg!)*fl(Z(st;?4CYxfly?`1!dHj4lF z_Uz6Zh>m@yF>E6yVTJEG93bAWnEq}J+!C!0By@ORv3nwp)4*U&)t4Wwef$fIV52=B z0NS;x!bd{~&T&IHtV~O&^fmm12W!|hT&hMq_)%j_rZ5-wisU&e9Y9^uhrwO-yy8ob zC4tkm1l}FDq%7Gh&$Bj#ah9tj>%?|>4fXn%xW$wn(5~Q}mFgpK#0Ld0_fQ*O7twtH z-S7QuYGR-L7C_%VM*WcjzX)n4=+)KVeldCe z+Y}4C$dmkW^pCKJ7zJ>$hNW8{cX!=DB~eE2qIe^4m{UWdY8(*EAz|z*Qs@%OUs;Yh_{#K9Z3wt})!gcCANm@Y4MA7VROtLTG zy#MbjgOs$i@n3$mwz#R+AJMy}MPby^>NgJ|)+#v9Mh12%Cj69Jcp=-#PQy&-Hb2EQ zv;G%k<|Y+w_{an`_8#g*7xGX`i_*W+!GPV)?j;RK9X_E_@%(( zO^OUKtT!d!EerroQ-jw|Ze5HbC+QQm2~GW%03g6o`|dVpTTPq~#jt|tFp&5Q73 zl@WJ$ZVGKZP(YE-v8GOE#Z07wBt_`D>{Z|*I5+8^nq9!#6dBZh7}JDGc-|lI1y9QE zJJhoVJU-Y!`zZ0hgWE|&y_Ek{3=DiqR_!=XTVBemN9bfU_;S2nZ30jCSQNw6X-O>B>oEc>25xGm;2Kq$hHuncZN zU~7IT!rvoKX=f^?l}~=jB-wy}Ko&#t3~v@7(z>!TSc)0(ScMQg)N{dRqh+OHf<)C~=KM?@x?YE6R6w=D&Fi=YJ-Co%Tl95&IY%p8Y2kg=g^gJveQwzSi z-?#ucrIPMwIKJsC6r629s8<s|Vv!YRw;PePNY_dWmF`;GC%5YJ^XCCf`hNZ}2bx414nDE(Mx&5IVW)CFD0r(65HJEuZu#1sw zrDWaF`gs1mD4LG`aVN>*l<{1?lLop?gyvucvyknfD#5d^I%;X_uFNa_hb&LRJ0} z{03#9FRA{<>m%;TJ>$~0_xlwa)V!}}He&h37qX1A_wKl4Rl{qJX+E+?mCqfo4^zsDd6k^(Y zJHhfE`(@P^Rthw14*2Xlx8j-;0Z%z>U#41UhcYAL6fLJF;V)!vv>KLaltp1FtnI$;NA|a`~GQG2!~DMbD7Ke`aVlnpyV2tiC>5@ECP``55KxP2=zGL{I9#uA!WXH0bGj2I2h=}g$J3H}fisAIl6ZOBJWWgRb@tyC$&bH@0 zzN+lD-DVGpp`UFIF-Hiz;owEcI<=^qlcWnr`KDaZj2!+yH}LDF=q%qDjb-+G+AJ0~ z-%jQ3)ER&wG^lEfZ}nFo1Oi#Z;Fh0Wu{8;6aL#{{X$Pcb;aLmG8+6=+eN`|NMMHaui z6;vnm92l;wz3~UqDijt2fjvR?&#+rR6mpq+b~-$325&IbMPqGs`_YzY(_ACCVEq1P zg6c(N)f*QHvTi+;fFEri)$Mk)`->1-)5!Rc+`~tBJaUdGLXUs96>j#}Bhk=KOkv1J zbKfy}@*-e6~Hn$@)c3bNBy8N>n- zQ8zD&|8e-$Hq@LptJBQhb#;v6;SdiW8biD;nXO+r(0VUV_@U*!2wNttp^h2x7t-*< z5OjEue+1`z|LjL<>8@)dSZ4^O&$n+!khN(a>qjvg!5xkU(ryc#``pjv;i0qW4gi3p z>~XyyJ1-c}=87Cl7+b~2RsQ^Q%Wer^|9AaKecH#w$WtLgztlo{;c}YZlYDb$(<2HC z)Rv>BF!KHAzg$aykWiq5wZP}IAG3%m=>0nCZ{KaMX)rckATt<@(9?2xKc4giimG2K zYRAT3Uyf{>WzN2qdsU-=wTiFPo=84C7yJ9x$FkGOBW!IQ5e9gI{OIrrm{+YLk<=_h zmUnyNf?AF+03>uZ^k)w25vu$&kz@9AfJr|4-7JrM2XHI;b(R{=a+uZRevaKU|brM~j16cO}>%p+{N@UN)7v-QOc%l!=gQy!_^qOt!;JghT z9^$l|Y}{eH=n}n(QKlzMQb0Y8!t7BPnR2Wwn}3p-S~kTbWVlbsDI5V?_k>;jqtJKU zUZsdp*AYLh3kNO~0B8;rM_gsLpZ`n~u*@G(L?@kGJiv(hpsXs$1vhNJ2=k%5!7!zv+G6dNZ0^^yi@D%n4b9#@VHEu~*CTh2REHae2j-{ru-?>b0Cow5EJZ)w8sq>GgZcY8 zjf8?X5r{W30Up$SEJ%zOzGG%MqXByZXAeZp@~w#&%fbQC_xA}9EspGmRhM$HuPw1a zEZ&^}o0{kjM*j4>T#!)X!Gm2A?dna*gc*Es)1DBM&(zM{Th)Is0qSWdz@*^gUTuim z4PIjFzpmTmdofcF_;D@VNi;uIq~oLM=Zee%^a3i+1Nm9n%F!eVAJm20j0R#ERE+S| z=>PIW^RAj{pV2?wexprXK!0$5P%zS(%=@g2NI{2}Y+1zrgeQ(jE^H-4lDIWfCe(99xs4Tkw5m#`U0RRxWU6FiNfNo>2dW=++iats> za6tQc*L1S*s(Kf>yk{%)<39M37>`Mb=)j(8HKOJqyIu!bA%LetV; z2z-j_x4g@5EXU32^)f3S503#}cwwU-17bhcdo01GWE%PlEqC60QXkC^KbgKh(q#88 zJfo1@Ykp|K<^c5e$Jw7rUf~bqb5HeySj4inIAcmMC?a$YP0?KMWif)i0EI`Q1zOdK`qc!V%9k4 zDW!K$82^96b4OiLR-ajRBZ6J}(&b`5Zf87kNJa4?D(1z$VWDwBXVpJS>YM&b|; zKiXtCpcl4$g$`u~G$9fC>=F@iTzyKQ@!Nuesr_y4s5(ZKT+uTPY5$}DPLWtJ=G(x3>)$nc@g zWb29Xe$6R|;)`u)m4_EJI{;x@sfIEB*YxKXM%8^QkvBy@?|oPl*1)K#;H0Je(c0;* zc>nfIh4TqnkVn3lGhU4sVar*iKfm0+L%Ua96Ji*7U4??KBV&{9=Q^Lf(W!b&3!sTJ z<}`B{OVeH6eD*!KYf6`a0hTUZ=1M2P!|WnLsq`laQ4Kz_8JxX-x5g%AC%n}4Oe%DB z)gf$3D}HBzU(naPZ7c#r`N0gD{&d#)JE6KNvBW$0DX#R*YVx2Vbph^(My+I4c4?r^ zbcmjeh-<2CVOZ)!b)w+;j>&y#i?Do8*Y7s#Dm+H!_@76m`*Zo`11=@(C-IaI%7qqs zp>hnF$6TV{1e8I#6t|<*{{EAF%7$+B?j*de(OQuzaZ@hNP7xxjc~=n-nUvOfV!gnu z{sAu44u#d;2L-Pd%tIafEG_97jaF7ey1ZZ?u!ao3ZF@<&$2vUme@#_#8>HT%@t%*! za{7={-mm)HjwOodf-A0pVbmHDzss@l%fe$@ssv|@DqZGk{Hju~KB}u|I28u;W0a|5 zs9)aa{Ap;)9)5j!K>dVecB*c^XiqI0X0{Jqkppd}Zt03+0CN#ngi0_WVFmdGV||PI zt_o-o$eq@@=v>Bso4!Klw;%b7!ZI@0sSPC(g(v(K zP3&&bI3Ehn#JTF%H}tI>s^95D*oixS!Q8vJ#-MzOR(-NDPY(Tdlt7j7J#+s2SVA|o zHuG(#l#EWy7BO~Z*ljM&RbCR`V5nFtOvO3)P;W2n=8mAT+3!V=_W`9P^>2>9%fh?Y z2h0&VZbiqqqAt(0apta4;)_5p1YInV8YvuIv*)$Y(Rh|}$su-l89{!s@Vly$?$MZg zQaTk|{WHFxo>2@Qif|e;IDeNx%xBg9{e}^~>x>3QZ-Z{6M>Sb92dC<^9pE+LP6Iq{$AhP1vX}_+ zl;0DF!oVK!@Jy30c!l@qVNW?9l8}k=1%ph4bJM2)7FTwWD<(NkgGXTsb=|LLW4?jFyC8P~5JHUv@cONjrt9^G*p%nwqYdqn3EZ&mM_0Z7kV5V5f{-SW6 zwMAz)CR%@-_MEQMJytXR>+AgCZhK=%Jj%ip;=VK~a^hDp(|EP!5o7PTg%WEnRcL?! zZV0iOhrfUKPUjVThN6IwW{xbp@1l=}KjB8qN0PFC*;rAO@g5UNx7>-A1%NGnkQspu zw~^hxo8eNJ-X+5~lu6Epq^SkI82v{^PMroVs2 zPUrr|x7ECsIW)-}IHi94<8IU+pY54=H!y~J_WH)Kzz330glxoqc$(-Rl42}Kvz>-B`ihr%lQ^{~s30YOy>a#*Pw;n5S#>$wgpRv5B6DNOxN5#O3BoUNX?q^N`B z1E!+T;0M>&FsNu_QV};AdN^m=5&E$;;<&4cZUijcQ7A>}BqXE8Ik$=Ty4}V4H=g+Y z7={Rp8SoF1U2$$?O~+iXKJgB2jaq#;uLBBm90?0AP_%qnSqtHcCgQ&~KS_TSVry1s zpB{e!QDHMR_v-)q)+G&(;g4hI&KaY{*Y*=d>I<`O4@%Ac;w6$}L1^JbZ?7yA18ida zauBOQJKBMI=eHYx@B$G}Vqy69j)tnsw!7gUZe2H_*eFk%cjK~{6y2`;vJkYTDACei zwpc7RDz&aqGIUSYFq}!1i6q2lhfL>I=rbx=GR^9 zcvNm&u57w&#Mxm9hRM`7V_3NyDy>Bm7Zm1HheCaYy z9`DES$@Q|5V(P1fdd6CTi5aI$7Jg-Z@7hMq{;~#=&4^m-Gp&`_i=+QlppK|(sU4HjU0vxds@v|?RIY*o9N${|?&vrDS!+ZTM zN|qK6pWJ2U&zC!Y_J0W3V?|+yt90K)0~d0mbN&|efLB8bif2Qyc>$-8`;-;*8AI14 zuW^AAo(M{!$f$Dmx1Zi!41Ru;lVGS~IN(}|OV&^t-HU~jV6Dr@5MyFeFC8zZPn!dl zN6{iLe&|?IaTgk|PG%owi0lc~3Pr<#j9U{ok#Nq z*3xHsz}QEjV)mCU;ddu-3O+T)A82#iGq8@(2QwwkgtmBR{OP{~Bq^IQ+eXO|b(Z{S zXry%gUbX5%{m|?xl0$gZ$NGRQ1TYx`A^Sz+TdpdVWPb9c^>_F7%f2Z9bgpP`L&`c} zi(bOBEYwd$h?|c_@7haWNx;oy{tUC9FnWeR&zWvdKn_`|7p#^9x&zC4R*7GOB&s7~ zSSxWz7n=bP{h5hl+PiU$n&DH*Uh=Ol+=4H;G94ME6NB~g7Q>D^b~X<1Q?1rVln#bE zqw<)G$azRS&gJ8t%n#J@u4=Nqc-0{;9Ph)}j#fpjZOHl&NX)qlHc zqS0dM?A*m15&Tx91i%^e2vrDwlF@{dO*WH5+UP`D3n3>1IUTLtMK9jtXs0(HH{A-+ zI+$)^LAS*AkY6-ViumTc_tyg9oxGOR*Kxlu6sP3i;Ejh~;(`zq3S)r7Jp$Mbo{y=; zc$=p`TEp~1z@C)r{`pKNB`RV6tn2WG>JZ$eq;m-eDC(XoO1#B z5ds@sbQ-@s6W6daf7{vh*G^=W^;xd9&L7;Nde8km??67`y(n=! z;iDRC>8-jTwRdyMx4?%g>Mce3DVnDvd<}xaJ#T4~wqL@rqx^v=AejoF86~~eklSTC z-$;Hnm&DE084@19ovyft>0>)mGcY0CqOT3tnO z-6v-GXU@747JUP#A3>@x+0Z9RpeYXcX(l#Q8tr0Si-MgT;2wb+)y5>NKg*zEj!^r2 zKE~9qOt_*bnQEb2uX}?E|F;kmd9a;$3^SyJ{WTIszdpj)$VA=WT4}0=({Y<5z1(QX7CdsS5X6j z1M=yYt}R;i9#{9mtWr9ghFG{?i=6Mn*L@xDWJTkcE4rX?FDbgs8tkG(Y;O3GXk#Ow z9kT(>l@LS}lJ4lN+W$DYAz#uP^OtnE6Xq2RIsds*)yDiysJ~dtx%I6ybc-K$1qDts z%)tuKf?Wn`Y8n7hTqm0OfF*$YDg84M8SQ8=C*IL3f7Y9r^R}AJnqKs$?oN*8F{Q<1 zGm54-sEebbc$A+GGLuM*diE~u>~Q4MaJ|mff7&ZfsVYm3h?W`589t-q$Oq8rwuoN0 z`Id5FpmGPW*8D@H#1)ub&%njr>;lRNO>ZAqg=b9X+^C28wnR%Yv9@%)>tH5I3+-?En2$g_gvYEqa7 zznIBU3I%!r7aX%KW2@aK*h!GBl}HwcT4qL*k%zAQOjL{I{jI&tlYfY9Xgb7h{6l^K zLYKl4Gxt%M`d}-{!pzr)buMR?d!zHOV=A=!UbhrJqbKst)15)C+vE}AxP|-OSO92w zKd}#*+5pgl_YFvPksgb50EpE}9{4s7ewX5{ z4EZEi(XuV`=_~Gj*QD5TWN$dtXK5_*88O@j_W>oFXevF_rA_*q{D5&pe1p$05Q;P{ z2YDlp>B#VWcvM93C5Iy|HrL-$&X?xM^cf@umte^qWA+)II~Bc|aOT*;jGQ#l#roSm zmpl8lh=I5LcyQe60z}ayb|3u(0I}*k!Q^s`X`%rRB1CJcGzkc|%%r)Bn$2j2+)$r7 zKjQSDrWBk~j z7Vmq{q`92{V5NwTY_z}<^OU+Ayh&YqVX;~Jq3+dwT_pX1s7zGZxnOo8Iy15RITumj z%nW(E-)Zs!MeL97Nbzw1eTH2lns!z#zH3)haR*^%T?;cabMWKYy3o$7C(x>1Q3@P2 zl=f}oLUz&DatAs(B(T8l^bSkfZ=GE;lHCLtqbEo$AO-UjHy9>Zg<`ASf`5ILeztYk zk^>Wq>dS88%PQc5I3c+o_horuspP-lZH6`#n--aHoyo%qOHfJEsv79HJ)@AhD`q!v&8k(iCx57R)EGd*CXw(rSd8gwlDfR2|Z97tU;&Ln&P+f zNTjoNsne~U?9~hSt&$@-Q*WV>36E2*X99kMP8ZvQN`gG_Xcj@eu-@@x)jF%iKz3JD zgl9&e=Ht7`o-QF)7rYR6@@NB&UHp8gKDhaJm-?Fq-^h0M0w`)%%BZSDV!uXfc4HEJ zeQ@W#Xkgh(WKjeLhPp5U@@h;B^T0i%|AL$tWIr=_@#}yd+qf*@*PYxbdp^DV{qXlh zMH|*Fu0-mTobEtO2*})zf9Um>0zow&w%D@!k9&D7UW7cc(9kiEDR~uc z_61avqe7SZ&Fc%x>2Bj>vfNs2;XGr1xq!vh_NzdG3G7IgNENsr+xj|WC+b>X+3#+f zzv^|CC@yWxVKdDdO+($zjpi)3HD&C^#|nIRkTDZ);*c&9sxN7XnBZ&JiCzmp`iP(6 zWl#(;4`rL3@|KHBF?Jp8;#XFn%5NfF?mPHIS|ns#R>-GvX#?=DIzmh_1_?s3KcBB{7?cs_cA-3e1rlPZbxMN=E&f2f6VST}EGLGt-&wkS2Z=}cc zAx{8Dn-XCtj!ck4%9Oi&vhoXuE#i$+k6SPKtLBRhefpJ!$2q}CqNDM}BEvc7eG6Av z%vpM>V$L&{(Rm&oVTvhTp}*jKxg)IphfC@(!|(C`CI#kY`FdMVny%#PE)d#JAe-^- zt-uT7k(Brlzxc}EGlPXaC2Q{&q~hJ~arXqFSrAB0-&AhmcSy%@D{WEs*j@pDGe2O? zenfz`WhCO1jizp~HDth>rHRJxFh%50Uq2N9_MZOL(BkzNV%<_yoX)wWCJGvP#ogVf?JriaI4A+ ztrky8${bah_QmO??5o#yzk3Kv6fPO-YlD4UHvwI}*j?hNP+89Xl8Rvp3q8ggaCYS% zZmene+C>DS68}Z%ryA^2pINYVSVGcmHSogno~yEaXHCNRh`k&TvUofF8%A1a=At6t zo7AsD?kf&WUgy=cR#vWlKwy}hKhc`Jxww3g&a7;oGraN$+3Sp5la1%@poKj~qJ5J{ zk-7SDzPPYTDIK(^lwFTp91mpH%t(p~VO;ca?t8=aBqJ=ocqfi*3dZro!td6j!7#1O z+M})AGb%lbyqI=c`mG`rcS{&9+`@5ma6{|_=b4PenTYHr40wM7NMq!Yl!j*$6gsB`|3X`Y$8at)iC!1%{x*C$2P1eB z=1}EN{I1KPsQId9lU(1%-|@#a@u)M0D>o97e~Vc4H~rh=w$LWH(kr!`CP&vIm@2{O z?gN#^kJhdDej4E2^+irNw*2-k;Qy$SS_W8eiToT)kyo-nafI+g9fV;}z@_2eE2QsY zR5sEAO=c-zmAi#CbRLpclwg0?@nw1TK^&4MZOB25g!ze?-@BN9l?V520K+yh#`Cz=3@XW0{@!zC};D9^}g9T6Rhz zG`JIiGgB=P^UvDhi03N5K4#sI`6JisAm`5^SZIEbA4??e=0o)0!dX<)RYqZv}x|!X8x@+=d|(Ov2m#! zViDLEw)U{v2jJ(|Gv$djT!y-{+Khv_e7z)>vz6GIf;SyZ@{!O7sqp!R2S3Y~KHa&v z7rj{;`h18C-^81Q(gCj}U^1|M~K){m|nExmEItUiB&_E68VR zqhhsZkH{c@EnD+*D2w}Z4(cMYw?TBzf!Z>i{x)1nhw`ejPLz-|&=5VYoMo7w6@<87 zT}5FNY8vcOHo2&0;#JO!v~35l-t>i1Ij!9r#0%Xs5?bFLHeMM9pL9}B3c48}fUACx z5DA{79bN5IM5Mgh9_$QDG10f>tfoR16?=)7`7!mgL{RrG(axK(KhfgJJjUBsXd$@_ zayPhXHGmP*O3@X((n89K`Feo=PSYi~!c@box4xn~x;s^8_@9Z1I30Vvib$B<@H(JX zERcb|+FVaC_%L1+{chC@zi;%C*h2UF(@8Ka>_xDyMp(U_$)0vAL|)W)OL0rVafCiN z`)>+nzutdxX=vzu%i*V{(9~{!y(^kGpyUfR`#k(L7`Ds!lS)`)O8F;Qv#D%so` zVY`zB2Ha&-?Bw*PcC+QF%SZ^ybnXJ3CfB{__fb!kgd4A!LH6`BLmUYlKQ3_-DMGeA z#WbDFFi0>EvdNIG&yYL4n3uyEEnq~RXfWEG=eg!z^i;B)bD6|#gNJ_9eD#)I`N`XL znAun^Fxg5sX;krT%ay)yBh>oZt%uvAVtZ8r+5O8NNR8wF^`+;$X+tRrWgzfmFq?rR zO;>(~bDGPXdW1f|W~MFNmA7l)N!5BjfmKyVkW+Ki$v!(F)NkMC=|+yV)YiQnS1F1< zJ#^IWsjPP?lqPrUo7)!m?cJ|nPfiZ%Fz99;{-2JPfuBSNqQXj=mheN^7Br&V@)7kC za!bbMUhv~U;SC7~ao;YWX?UQa_zs3MmHqa5n|yUkpp)?ST?gt;w8tiqQz()M*DKh$ zXZClSNssf66X>yrimzg;ZdzXyNNa1|mbd10D&Rt?KkCz2D0F6}m}}0}TiZDWU&=@+ zCd@L%(upG*VcF+8uSc)r$v5Ho^&^34OBj6V`|!EcdIA&?s{}p!5OyuuzNP6 zX@eScxOo?x?6;l6Bk@G=@ph!alx79P?KqW1t;%;+{{gmXMyjRFuYWgq1R>eBTh!=MwCJLFKFuq=JOd;I zlF>BqOZCl0C=@65_H2MrsQuNQ7xy1b(RyYcs}XON!Lw`>-B?&@JQcu;{T#MQWj@O#!U@wbcmiX{_Hk}0(I=G!flZ-@((qy zu5Kzv-4x&VmA$cm)iW{-n&4{^DSLsm?LyI9?AFbUi3sgi^sTQ)rTv zbMkI}=v9=7c9-rQQBRbxO8!g!b1&i+6w>wd)o#S(WEhA`!dGoT>azRtUk3&D&26FV z$V?KYCHrxQjT|S)FE5DCiC7O>_qf!o$ z=V=XTZgQY6r3pzZOrFJATHdHV!vjq_nNyP|uguT6hjsEe@B2l6RxQD=uxx(%-F!t^ z&XQd{+qJ-!kN@|!zubuaZ7H6o1wa9nj=1n+{K_(Pj}ZrAC62^^){e{Oy3lp_J?Md1 zvYGu^9TZ0^KVgojsUW*^|u^1$lO`_U%Y;-{^TTPVc9323Ah&KRdaEp zl+TkSJ7|O=uAdiPs|h~n_(RTy>ejFyC)FKTRL_aN7q^;q}6h9Z-2(*0O!q zUzyePGJO%q0t7 zLa-?G@)TcA?W}GoC{pxE&+-Yv{c~oi2F1s=w&knOPB57fm@AYp` zoDrKR&(J){f-P4SBJbmPxtxEmD61ZS~CHzq>hBiEBr>w-jlCjId?Vg1hGC0RE2$br^PZ`jiFqKTC8IwR-zR33q z)D)>ZTi9n7ZY@o3K4)^g$F-VAvJIo?tWfRIXE706E=3YWnU2dO4-EE@vzu-fp!z7U z_Uj)_?6H(7zL68=j-L45S$~AM+7k|7Uzw}ua!js`&e^4Ghd_#qf+1hIFVRfpW9y6_cSdA~2zYsdU zTvTLm7t24Pzz7*XyFK#hq41o&;5)MB`kV?LCApK7Km)Tw}PO$4C>XV~uU_q}` zulT60tf-One5j*H6p4e9TAL%Lui}ljl?r0QvgyJim4B1n`sX0=C66VWjK8@b1-R?( zw*36Cqc}c@`(32jNos%FZ#J^NbQ^zV=$_u^vqK+Tn7{P=%&J*^T{<)ec;2KqdpI?L zExf6*)0yGy`Up718P=K;-K=^3NS!ZYxiIhxyuU+<;iZsj<%bpHr=GvP%(se567Ap+ z`vG%PMkc?YBd)Qn1O?0W{|j&H1Vw2S8MWTJb(2EDP;YqkAKHL9XA~ z!=!Zdm+`vCuMjG0H+)Znjj-^X11z(~?jyDRH;<+D>?W5*F!#&yt@F9_#{&4EC7Pua5l0 zx$QL05rE$#Uvg%!^oayN5RMuBTAl@rlMXSC>!a9n-BaU?Uvjc&_cb^!s{7Nm`HVv& zcMp4*oV0g2xL%gLtX_f&Eop`5ftaK<+h*gyOM)ni@LJ!p{Iv_^_m=YK#Pmk2y`m@b zRfR><3&XRy^*-6{jNT%Yg1v1wu3zL@aPz6TRIt$rBD_fkHnlV6Ryafn+D5O35emLZ z|9CTZ48Qp{Sv}p9d^Uy8baAw@KrS=_xo24mk!lf5H%b))iA@e18JYL}=G4f(($R=K?! zQ4r6PNB6|a?6|a9L&axS;UH3#LEKkAJY%hm9be;1$jZkqplL-nY@iDWoA1_#XvwuT zwjcXKg~~rEyPcLA-dcnwR9zq#5WkEiowWZOU0z%Q%P`10{r2^K3VQijK<AcBb()j7j$`hVt$v1#} zH#IeRop87gc$#uL_)rqu69kW2`1xXCe&I&zIiIFmiI)d_zsWB8G3TrFzqE7-OhLY1 z?nKf)Znfej(%*P@r|j;z8D#HvUJfGOdZIc>&9v`6rw!bLC2# zNM4&@GgDZsS!K4siGJT|gg6o};rb8L1&HxWN>uywmBL8}a2M%76cVd>lxz;}{6$fb zTcsPb&g_IF51fa(U+KIm>QPWRWcqiNUo0_x?06M(-n!(CdU>PSeYcX_Co5E+j74Vspw(HB6sQ@XOZVGyZ&{&K|`pS?d; z7yDps-gNVO04sgfp|!Z(FJ3!yX^xRiQ_mS2mV&YO&C)xGb!##2b~K1FX#2h;-C943 zPY`FOrweh$gTm&k523&G^<+;Ozs+&LI8*m=v)7PSkMH>JTI`%!Y!7_}l!&FwlT>M?>Mj=H*o=M$xvvS{W zN+2NOPQspt6Pt}rs&wz0Hy)9~IqRX~(I@ltSx9OhFcneP z6b%bn(+eNCLXR8L+>0*w!ed&@YQZo5VEJa9e(c$jxCkP3j!~n>_$3<__{Al{kY)zbgFoURs;#HGQ{$=ADED zeWoluN2~M__}CMcu+MfR52laipZh3$B**O)35c5wQ9d_d_*zC?_|K=L{QVv=!iNt! zuKG&2sN;aR@bmT%(9H}4s9R>hF>m6jX_AUID|f7{C{TX*xtct&L_o+=(%=lR<9W&N z-tMVvT+z>bzM zqDlI8UzT?m>c!xYvjFT4H6G(`uJmniXwb;PRB;z^qKd9F~-@V+76&5=I>KYhc! zI@-3ZvIYH{RV!Fg^Vf&Rs-Hex;~w^^Y!tjyze$6oO*U&FLLK^_ZWlN4NZpT{Ti!Eq zS?R1-jio1enbdbfL3}1zX<}ijK%0`(3844>NvsC*vUVN?W3i`hC&l)=lZ&pSV^bYq z@v$h#*Zo9v-4=3gRZ<}K(xvG1!0i_MT$ch_4q=6i-%!=X=B0JtxW?I%aKCY2FU)?oeEPBbZ=0%=B92*z zi1H9$L@P44P9eyeH`HbR%&dIZsd*|4cpO_q-v=l>D*ND!r@V%tlPtk0GgO^Sd+Z(lw=w>+$S7Ut7Vnr_W^ zxg0I$WpKULBe3ZAB!OieVwiP_o*WUgellNYSs*rrsUGS3nYZ=rw|}!f0IOCacb;?f zrG0rFZ|lhbZuo-zc{98+S#R!w{AxVh!I3uJf3)#EW^#n2M6GAHtPn3cW8|TaScj7A z!3I8976cpP0r(tD{D=e!oP}328$C!@*D$^M@!xqA@rD)nFIscoT&65;dsV|v<`?LJ zwsuMV(uhZ&9Fv{DbK|O2(XXQy=nr`EmVfjHryZ9>@9#%@td~8S+|?TOX^s%-r3LN& zV9I0k49>K02??g%-#pCTi0ZHtX5-Lr`SV-{`x71#ON$F`A-4c6fS2j@2d{M!>zP^L zSjK}`lXVsMs5Ja`^;=hE@jFVEf;cie@85Ue3i+ztL@zpvES#8Zuu38DbF=`iNQ6Ui zhVeX8`_E;cfup2w&*xq4nKg2N{6m4`k0Giue^&2ZQ4#73?J_t#I#5u)Tn^YY4Np<$bGZCm?A^G7 ziTJzEn%>OZ8B!jcLMCIX^b(~)N5NjRY&9tP)hpQCm)Y8fsMDoOam7DvWrLJc&}g08 zd!ce}aO3uUlYe;c;-T0%I{2ZlEl*8BSNrM|-T zA{T(R$K@OUv86$zEnaDmW5k-5*(RE-;!AlY&uwUDK&A?}Z-d%1V>|}l^6o~3={}gI zno;z)AfC6x>CqD8Mv@;d^W^E^SdMo7bGSDDweq+Y?;)rNq_qxobxuDQbdlO}MfH_vmB-IGHPt$u7 zMT^`UFZmb~OW(GOn|*P|V)+jVxx#kg7V@P2PRZkI@{-qgLk_{1rDsxVNnnDc;xB_> zhq-~n1stHOp|~1Fx39rst^hif*D35pXJJ*=jQ&ie8tM-xqvIXE_Bp-~uVV0rul%-> z{Sb}?iqg@g~MO!=^GRMVRpL=^Zv+U5sakVG_Xq{ z&@w5AzPt)3-|(VSX1hgd>adpCxZ~E6axg^jL&U;nFwm! zb@>k&UsQ~72=|k4%2i2h;VM}FCWI{0lOQO?kiXfA+f4E0#Uk3${3(? z;h*5C!$Xwb6*x{}KYERa9v?(b$PbIAxVv&V;uht9SKod*LG8y~Wx|8)a_}I~RR3=o zScZE>F{P2nNfY}H)XVar*`4zT(LQsL`sgn3Z!RR_=7!9_ZnVPV&)>3nnt9q%ZQyFY z``KrQ$Xsd8nY2y(0Yk%;b@(cg-TD2e-vtl=n5k2OW^KsGBa5r&;@t2~o+e;5GCb~$ z3Xi@~;r~s@pPQ@xN+}y>65mFMme;P+O)?e~Q7FEpe7?z|s z)+|R*QPxH->j&a7n~#Rrma%{Rz9h%Paa9=m2`6lA`}r><1T=Dy_9E#)^phDfNlhtB z9Es5P?FcAZO`I)48EWi`Rau`$JIjQy-@iNw5;rtxFkkBiF8GzjO!3#$uiHx^7b?@{ zA3fuU{hZEt$&&5*<)^|~snne8!Eu)cT-*qF9X`sr`vG`XvE>#F&L1GB3-alufIhok z=Wk%G9HUN#b^3}q>O9Ic~-vU`DF-gNNZ9M48C@cFf%@?p&;kiN92VFjUznolXurRbDTz^qXPxTJjeuuti)}}ezV94EW{CIk&D`6 zLN0pH=c8^xy zmpf9AG5w4SPi4|--~5k4!uG-5l#P=h;;wwO_?rVj_krrnLLrCikdvyTTTc)1#&6sr z{#Sck6jOORO^rCP%&zZ?SK+%S?&wg`GlR0SA5mTnq?$ah$ls#l3TQVjwtaQf8*a5X z%cL~VXs-Pc_=0!%Fy4Ux~ri^W00C?y5iD|%Fq_?0=T@i6;WxQ*>cD48N0B6N9yzQl#U zA3h&MZT1N{$GzqbuKxxRXsE)*Y!$0;kf+615sK}m>;CGEnFYGg~h!1fVREO;s zLlhQ1%T^7Lh)M-153NwYT`%Tkn8?xrLi-JcRN0fRd!j8 zhV83rxfF2n4}=Su!ZWJbN2ZF495-!;oqu}}A z?w9Id+Rr+N#U55J-&BpcFtfZhzE(b*zc+!u7VB5;^p*HO=jDc8|Moi3&9JTTP^-|| zTcoOWhkWoo${RPkm5C6`6%ylBh&S!8F|g=%2g|N&dpx>Z%ClEg1B`=B&GnTu<$mv#m3D3dkT-J`ur#@yZW1~_@b ztExjn15Kl#D;E}-L7Z?DKl~-!V|{SqcL^UHLt;FGBZ?ZS!q#=rCnGOG^lbEMcE2NHUR#gSQ9JSFl`mxAWlpm` z=Z@ucVA|b))00w0@Dgc-$B>w^GmwzpS>Ga+ikT?d@TRzX;#w37`huk87>Woa@g_I} zU)T2bd)F#b5XhYDSWU=D4}{~O7AbC}WhJPg&dc2T?jA-@|-Ck*`W4(B)5eAfgfs??vy!3$sIa6ag* z+THHtzjg1Yy98ks-kf{Y8W`2A=ct#Ll5b{d+hsd;44eCuxU+sfv=W%{DSn z_F)ucdg6G>L_Xm;)pILS8bL?Gr2$lgp;&Dg$B3&WCOOb3G}Kz`9OEhO{^Q&|KgsDY zA{(V}6pvwyeIQ2BfSE}6ZG$xqp7Z#lzW!paSEaYFkoNo(pozG(K=6a=Y@vKWXouTm zqHKyV0@73a`gSDYDcTnO^n=nwuvu-Ssa{9fA-lUY;3w1TF65}S`uZ^QUu0+`c`h>cUlF}WAFmBcI6{v~^Z&a|zhJHU(j&=< zc}*ltg#C)s=RO43B-EnSh9vX%E?;nXGVn<^zKnrP?Jo{q55E}Ciktily6|T&2UZ8eZXE{nKk1v^>Vh_aQTWxGfw@251;xWx&!eb>F6H~g(|Nw zg(iK1lMYesUKdW8$YC=re08d5_vRxfz|-;BJ30o=d7pH0cS*20R5~U}siZraw(w0V zKL%>fmgV)UXfw(s3^wS3F_(TMgI|7AL3wZus|=WS0a2^< z0SjL<^;%I>gFwCQwbS4A&+X|=bH-SWi!1US^7XV^(bVxfrYwH|T*UU-7j}ZMN*(#u z_rG5ci~RwshQfbfTY~IXRi~RmFY2e9_)G#vE5DZnG1-_tGQyB$2;jvo%M&>AVIpef z#ggZge7wa>kD4_diqjj4JEXM`3;r+;m~ zhy0;=qC>t#!8gjqD_YuXzrRG%!3ftuNVhT)JqAzu1QF%W*`q)EW+}}reT`>9HN5uR znYIcd&O2p@jAdlFpGY>*`2>iw*_DX8qkcLS*&As9efmjuJ)EBZ+6WNS(mdd4 znb~rb;CN5or%LBp+J~e_n7w_gUM}5v{HEK3S6^B}TmZ?yS|yqthAst}Z@~-qk(Sk9TW{>X80qXEo^h4Q)tb95{q5P8 zuo}__0mQtNwO!WPlRaB4GbN!!q_XWR)%nQ3?m-cI+_#5ukgYC0_I$33n8Mnr^6;kmhMSjdO`#hv)@8EyJ1p#*j|*jcwv7Gs z=EDB-u6W<4KO;1Yoh(TG1gL$F^$^60j{KLYjMd^4M0c`-lidL)vwQOcOXVri6o;s} zi2N!#ZwZ5;Z5UNt_on&CzCctjgLQnz)PSKZbtNM_;kTH)j#;|(SDl(f1*45E)|BxU zuXhHXX??mU5|U~Oe(F0hbn)l7^>-c6B;x!DGiaK^!+ujb=Uil!J{o=X!g8*>Q}>zz zQoTc>j2gRL?!nbH=c=7`a^dK6i#Y4}DpuyB@+v@#(J-VRobg-DP&AK4b-}q7Y^r+E zaZt8?dLm9LDqcNewciaRLMAmS~kk zcVaeHZ=Z93l)Y%Y5%A^LKs>EN494I`Wj=AO)ei+n%N|kU-jzK_HVFn@EO~X%A8M3c zDv*ib602@{Wd_T|u!~|hWYw#ejgN}XBCn_xaK=5qMLdDs)cE4_)E^{i>r{H$Kq6&P zu-L5-4z3R?CZQT@%uYaJQeA%kSZdW;(BzIF+pWy%zvbnr;d5uY{3)x?!7%<$t~0pK zeBj<06kwkHvAlsJsN+X2o5UHZ^%A9396Vue=uF$m+1EYV577}8m^e-c*d*)sB7_$#zm)T5K~)0L zt{@yR0l@Lts>RebJOZ;06uF{{vM@^MVnJwgi_U;D(3QGU zvL=E!OI-9G7-kh^HQqZ%S8+-S)ufGQvP0se}a{;Ey z$*MgADGNr=nmxHqQrQ??ktK5t9qiyzV0O#f9s=za?PPoZaogo`y^6hoM8|%5QE}&y zg*3nBEs3e%9p%s@jxQoj9OZu>u~MYYS+(UyZ;`daP%R*o~TL7 zV@bqA9X`E29kyk$K~_B!e>MFBL+ZVwyBrBf=l*##p6KiUF9yoFxXp>$r5n4wNu8`l_MaMS_dvvgp(E&m5cu|J7 zS=41$--LzQF!pdN9Jl&8tHjMLTzMQPWa%;VLbtOI4^FaW%qQBy^k^V z8wpsZU1}R#C+tbh={}0kldN)z*o(S|e0T35_=K}IZ>sD%wqV3Y_!0F+g3!KCii-WJCRfK01j~< zk^KZ%ygMNfZ*R~zIsyKabr-olT$o>=r%L`{(wybF?DcO^pFQH>!auEAAAk!+<>tU5 z+!G>@fekT~yYT=j9<#+P=k$yoml!OGG;VZ>n@h3KDgt z`AJb1;E@Tp{k$j+ode1JcX<^dI!d}l^Gtnuu5?c!!2A{~q2bo!B$> zW^#~!jV_@vs0<8gcX?mzee&6}WgM7V4~l>%aQng1SRg+YCAXRLm#o58s8qLtbk@Xpnn@)!vUry&uiPBiE(0oFrSZ z3*(V?JTQ^_6J)DTFyMQ=w*uYDML@1X4e`tqaYiW4%GIk8U7v9fsFQKI(iFt}$U3rf za89NXldy}E{S7_~Gh#{D5CBo9+|LXl=9ufHN-vM@37~?rEJc^q5R3T<3P}gPNSxoj z8rC(kiB4PIAN-1O&^8B8y$gOvnUFoZ3{E=Fv5%1K!E^_F&J@KuumB`&oU?0Q1=H+RrdXew^$}p2AhNH1K!EdQY zt``=3cU^^BH2V6^lK>0^IGBDy)xZxK1cuz1>&($=0@J zT}8$ou7%$8b%{XkFc!qwv7vR zp3n6EQj`XMf0tOi_~KIS<@1@_HyC%`V1uqi2ewdOs@6U6z<1w>G^e}0S0LUyjk+(d zy%us>y&tpPiY<9;)G~Mw4V}V%D{NSC28m@j_q_T+bd(iv^fsDn*L?xelEns+e z{_nCo3~{`oMx7>nU%`R)2zi6d15b+sg?!C!gqqq-gn(K~5^r<~sgzV)lq_(?pEG;{ zrQ3;T_m2F~w0dNJn_^NE=ou*&z3(0JKIe8zfuX9DD`H&bS09_Jd1AiO=i;hTg$@p5 z%Q9d=_09C_6zms}w%PBZ+O577yX3EG(cv^n4L?-wEFW8<5T~c@xE6fG7&?&jAxdk! zH6MvG*Jb8keavj6QHlo$2JCziX})I`ocu(|IIC!->(3WnY8(8<)?Hx#c6=0Cl)8sT zsr4x8=NoiG#I%o8fb_H$TzMFJ{c<-XUDGc&;4_ca#z{Q7sZN5jIG;3kJw&#s3|D_Qz^G{j~amGjHK-h`KDr?n_}wA5 zOY{+zK1x|VHM00_nR~S20M&IO*rO+dB|pUGICh^%o(iVpuC-IQ^k;%iHfNPwtq?9# zO3pqgSheBDrH_@@!uJ_aSGegNL%S|vVX}`2V?+i^(@moZyb9f2zM*P@R=2|L=KU+9 zX{suEEZgFh4{D8@Q*qY;+vh#kzm*dJ!AvW?T=h9w7{wbKmi5s|ynf&okMvu-=@*P0 zupE_aA7;j6K<{9%14gZ{qT06Z76(59wO^smHg}{MmG}~~&UZaJTWeLg642Gk5pG?| zgV_2t4qh56*mDidP+jfpMZk_UrPHv2R@`Vdyh19)T$8#jFBlKlHdTGgd=miPMGx6t zQ~aK`6a+SuMLm5Ej25`D=RdpP(Vs_V@VKl;HfGQW0=CDn!v9@bObXM~rU`H7XjL;g zJ!?8F)Lj;&F*0=DD>7$S`sBx>Q6)q66Th>|h^!}H!C%*~>U7OYGttSZbkwYpWt8Y08HpSxpjaHfZJ#tDW4idX?@3fs@h?u^nF%SE zzZHz&Ula_Z;}@(gL6y4Hx~ClvwBymiY;IPL4Ly`^emahTMEr{*ep9QFcf~MB?>4&SlX)+V6@O^%A64adXN-#!Ut9QCVqi6wtvw&IuFX4NC z0te?Sh1YT65@Rar!BUqh-YN7oOSH3AP0F|^zM0Q0)*71O3B$^b4Gn%R=d!<3R!s^_ z9$8nER-Y6Na2NN1r)Uqh$jGn_HwH6Poa}OGg)|m?DxdvV`m!x zW#v!$IL`QkJgx1?C6HNd0s@?`yP$Aj6}pUji5S?r>`h$Z;2?xdhDZJMdU(jcv$2sa zvcwcgy(jZ#ZCDHN>fOg4PT%UMm(!%KQSs)bBA9Xhv_~;^mh~i+b}4c>rHVQ>vDt*0 z0olZoPYrNM0Nk8yk@boX+I}>i8&z#K@XF0$8>qaxYwI1`1nk}fA{^5o-;;R%2(4)a zG5%$-ppHhQ2P$8uioV*7@2cB2vDdHf0VZq$YM&*GJio>E{mvhjC9l|#kTUgVEN#X} zmiz@SCG6wzW=YWtR7(I;zQP~s)t6v1%RkdZlKE}< z_&P$9fWIu!+82cRE~0Rse(0P?E7%C8@*(`lf0+>oT1I?2;>QiLJ$B)+KI--El&_Z#;1b!x(YUXG=Keyo2r+m?m`W^|!f8Bap_t8SV)iSNl@8no(4 z;B0LMd6P}8$ide=-r+a5>Ja;a(PP>x{N zo9WkfSbTed*@6!S3wAKQjYB|ffeZ|dgwO2LhF>xuoWp3LDcibvNg=WK&xJ%B5D|o?uMj_~^H{EQD-r|82^xoVOEwnS*9B)KO z9=xG47)Pmdo|f#!l3lpk5mP@AdpX+eGIAO3X$CY{co42ZUlcgC74T8vu9)_}n`!%a|3mHL zLtnuC$;S={MO`p4$jaXL!$Rbv9&p-S*D))6)SqAueoJQezpo46bDtj@Is&^vEUF>q zhQVD;JS)&SKTJR&_vJ+l=;mXNy(bWaaRX=UE9Q=TEn$xzRITEg3=qK zSwAu&h{4yZ@oQXqxi%gHWL({t#u{1sMBtl!Uvl+(d;n>&-mL$>BaUe0LS>`bPo%G} zMl^BJVT%EJ3sEOG!rpk`c(M`m5EM2|!zx0P#54XL5eQElJp%NBIT?HbWc2LpM(?ej zcMf!9zl}%4h0|PG9dwU0_>zRtFhWA1u?Xs8)(3EhHDT;=du_YPfF*fh6Zn+Hwr6Er zZf+s+1<6V9mu5Kx?)lj#q6WoKvaaoGEV10r%N^9|V>OS^|Gt)teun+`oAsmUL*yL>?W|s z#$wwt zf?!?!_V8)+rxf1E;EQQUryFNte3E~e=8 zaR8JO`D5*9$@BNIR1gA4ApG@}#GNvi8vJAEy#OUF4X>sxbj8vV|8eOnbZ(i+;M{Au zdisqXk{9lmY}@sK@gMOE(j0@jo29=O6oK|{=S0P>ZPFd?`?<#Qd1b9)))jN-U>FEa zh$r7;5#pX8s$#xp?S2GqKa&oAD_Opa?|QZY)HGVM*%FF4#(Gtn>>Qjf?^)Aht9b|S zn)6+$qs#n1Nn7e;Gh|~`gwP(IS*Wby(UWFyMCr#rqlCFtNw(t zqS>wpt8QcCkj6}yL%)~MLFzSTlgh#MCjJ6TyD1Y<7f zc9$+5;(GVjOog%l+ZQsWCP&&hoA{6PgffkUFy-ay$u%By;rQEscK!qhq2xM!PcLb@ zPk9FEa4Kce5r3t@U+#O~ryWdt_tqcY8;B8wLvqEduy*kL4M%&tz)DH8reipTTsQqD z>P4E0Ki%EHJh!&L#`a^B?hI}a0UuY6V)ep|YD>yxXI2k*YmIi9F*^oZPLq_8ccp)} z80mqG_0H5@@Dv=lvH~^#wY8_F6ikW)U1g*^? z6?du^LpJ6v(Gl6D&zwBqU{k-I0q;Ua*W@Ps8m_jDZv##1(Xzq%!&%*uenk9MN`zGT zZe61u6QAj~`QxR+GaaxICExPIlRCx0RD9Qe{|H}Mw0$H`zCgrore81T7$vTc^W-34 zjFt+wt8n~gtQU(up?T)UuMLt$o_F7LUq+yAs?vz z!jdBvdREyCHHLg=R2&!Gg#}IEGgb&fQZ&M9V(SHW;oBO!P7rj(Pu!S!(dLb#v%^j| z+cz#kIDspXai-<$bNz`^{MY83;sZwYTOI7Ia37K~_WqLH*@K1~_n}n+os;qUqf}>M zrX$|gHM;ccq0)Jm&>^sc*q`2?KrE?h(CA64ko~9^M7aJO9DNQW^_Jb)@qpubD>4xF zkQ2Y`Lqqle-UMpTPo6s_Gk}Ab^^wo{$g28YOZ$ACkg+`)A<<)kV z)j!Dy2Mp#w%0yAn=x#nxM45e}XWiVmejSLm4PLvJ&(p+<^_2FP^gf#V>7Yz-T$SrP zu4CZAjMouQIwyJ;4DnxqnlarnI5IVT?YNEcsKu32#6*V5^E>@^o%$90z&oEPC3xPl z-A2(>>KjngNA9NBu%~!Qcieb#l}r*qVMe42F$}MH{!NYZ{vMNH&Hl9wWd-9_X^~q|u+!}Osrb@@j&6j^qKS~$TjiTEY;#-G^EWYO zTZ1!KEC|+FKzGSh&sf8Pl!EL%&tyNH#~0g^QI;@F*lAz z+dmB%tH3Kuig(^E+W${|o~mjJix6#t&1ih(8&rXc$K5YsWS7=@yvoOQ7SZ$jYEn{l zgRPW$q*wZbps2l5Br3JGQ^4trczpHW86F0etzaXDmQcu(zrcr!d;p>l`UA{y6 z_|09Zq=``?0ZFEEe}ZlxB+t!t^MM2wN7|@4P=)ZF7c4wi@dpbDSBH25BSV=+gw`tj znDQdmK0Yqxaej3fR^()y0(a1Hu1iy5e1maL5+QLR`>`ox`!jjf{nO8jKujrsaCbXu zcQmZ&$RbL(~{88poItel6=2%DcQv31xlU*t&SASm%Wb~TEX)&1qcPs(p5O1mBN(DEqy+~WD(Z%EN9i| z*Tk@bJi@RKc88|&8A-?Tqh*Jq=C+*GC*Jj+kN@??o5QKec?~J}4-~TbaJHKs?6-O| zTIwpqF}>B!h*sYCFQami2dC=EV%UsJD)PHOgayj=Q-&^c>v0sm2RHkoFSm%+iL4d0IW<;3q8^lhmSuc`nHD7-hl&ix4!&G#vjym z(tjU%xHTTxP$;#5jHA9x72b~1p#i>UrUMmqVcPU(h>MS6;XFgA!x3HOHy~!jC2h+h z6kDWA_MR>6>tycC|KsQ^1Db08227VU62e0_NGL7RNS7kg4N6F-)Ceh+2I=k)5a}8z zAl*m}1gQZ-YK+Bs_kZ{Kd^$U4JLkTC*L9gRdt7LvS>EB8b?NPYj6j%b`@MvtRLlHi zM)6rhGY-J!Cp69}2uQN6)*4%#tT_V_s?lKnVajWq{SLP$ix8zUknDm-Xjzc1avHo$ zFcA*76qTG+7-4;A69c|32&Q0;MdCZ1!lV+W&@@2s+gemG z$uupWzO>hE$+A&@A)68g6@TyGS5W|UF7EfH-?fI}$R=Xh`HvYB0^38=Qc5Yb69(@;fZacl3c9nEZx*!u>bH`!SMt(N)^1rXlj zAekC_iZud9?G}UBweJo;V@^+b@uWY_IX?f@g!vg|<9-!D^Ul|{9}D@V@+o2G(b?@> z!^0`=l<&IomJTSu;B3ubNq5wGpS4Gr6l_ z?Cq)kDAzUHLAmX|HI3dV4WIc(j6)3XJ>}S&@clTh>jR@i zd&RMusl+D?6fFayK?UK;y}@T%fDo1l75 z;nCZH+-+e+h>_XNb+OX^^GbV`2~72T^34wkbZ{AItO1Q%-5v(bWF_}`fa8&}I}|Zb z3>-2OJj4chV-O1RRuytY35w(u(d@3Rt5DXXHCIMCEPoHw>pK{I8lwNOzIP8pFt;0GubFX$`Lo=AV+?MBRfDepUGWzDM>Ym5$C%kr%S zYs!1=t2>Lj5=iP#118k|z)qa?^_{r?iQQTvV_z#I_~TR0aBwWX9NFk(TI1b^gB!#7 z&YkDMGHEozz>%WIRE1aIMJVH2X^V26`T6VP;Fs9n*gs9~#iN&8Hh)XcOwUbMg$L%q zOO+<97xi?-7#jL{)0Hq8g}?Rza!d1MU$gznML%-R zHm`df+pfFe{N1VSBC7zRWaZ;F`R#_rYEJ5qm@brB|~;LN6K z>tucq`d^P`#rFZ3rL^Zw^5gEBi68G!iCaF=rxT z6!`XM<>AioX-2{8xl{0<*s?90opE8@e3o9Ryb0JS7&cO6;MaMy*_*1LWjg-{!**v> z+)gXTp{_iyoPVe^AhUWJKe9&H9L_&hZ(a8zv=8~H^##L9(XQ;WjJFK242}{x923czx5CEHtAVE zj3q1O;FQOJuyYXD|Y}A7t#_Afk4m!BdVL0j1C2a z09!|dm2ar4G~Guw`sb0WmR_l#tEq;}uVZ6@@wp(po^;akZNS!fJE!mPB7% zw4*?&fWCp!gG|!*qy0k6k-_N`j2hmMua!k4IPI%d46w+XBQOteM*QuxG5uXZs6HGK z|UXa4zWKVkN_(?u_)rR|%!wifZ4W@`G^^)N290`6|z(Q2|{`qNY)@C>IM zbkz{}%(k-~F}X0jkNiHOpYYhFT=e-XGno>#I*5999zj z9)A5BkDr(ktV=Yv_w-S-|eI>UT)A?g$~4`DhPF6bie{iPF@gJ8Q#qO%1~Z_ zS@ba3sgWgy;5YTzfZ}Zntnb(2VuFns03#NkPULndH5!U~Ts*`zc5Hmt&DNGI9`H zRI$!OZM)59lIk!1Rmz6Nmq3U;jdYQCpSRk6=2~EIc$oj*VX_I%Di+5y@t~8x?JYO) zA9{PF2yBf1p?kg<8@a5Fk}i_5ow;U{>`t4Bna|+qfUZxUfDhTWr7lHtgAcGg!rbB# znLgvP4HDDqmSuzH>-vO#Ry%P8yxDDlO; z=G4`NT~L(w+;;%?uK99chVXD;4q!9<1DXHrW5JiA3@zWpWotp6HLRW)nj|X4+jvpJ9i-wuW@eTAFoMnvk0mX5i*gmo)r%h{hJdcR&ee)%@35k%n%5w zEWW>Z+U^RM_YNJaxuJZxi2m!59B4;6#d z<1yx}8l?JMN7kTXk9Y*{ByND6Gwng@GeYodQY{ z(N6ko%9QSY=I)bmyUp)NOn%WO%~UO)7)Zh1@bP`9B;Ns6WFgXXJT2iie}jkg*f0xEAHiG;B}D{-Os#%Z>jt;OFBI} zG;@gi`S6mvgcQ7(5B#zm3xBt+C-Lg-0sy&_lT&H}?_?~7zO|1~l1yYrwxjDD?DqDwCGHK?HBdkA4@H}|OdEXiwcS_H$@86%EQ|w5H z5=b9jMlY+P!^;vl7B$a><_Dv_u{JtDHv*fZWBBd*aE6{P-1p108^H3i6>=LNjsc^; zTWx4UFq4?P_P7KYl%#@G^PoBdv=+(GXBQRHv24C0g+roJxs2xT11;HnT7v*Ek;!?6 zvoSgYE$lWx!pCs(MjCg}$Rbn`*HzKa0)9ljoxhNYE%D^N`oID*KNriv+SE{%_9VRj zvso6FNYQW3vX_8+i&J3X&YmW4#6Vs6Grlw2%h1k4rymqgJ9wii_#Z8MRrkJ|FIjxd zegg*zXl8++Vvi5J35@d`65hAXSzg%^@=>LhKH`QoNq9XWgQ5TMiC@=RKC=_{f3q0O zG-^P5)2S2*c+|hKX)DgPv2Z3bV2+IMuSY8Ze!FmhXt)WCgeYCl&pZkrk|Sv5v62ii zxTpgr&boLpbq#XaSky0jwQDo?^5%`i?V(tIAc;di+ucj^IkVuo*SG<61I1lDC_6rh z&IC}GJ2)lii*tp{aN>wOz!wHg$@PASZYFPicqfoH7lfOa| z_A+k8h{#rv@1~wtwakh2U}|k6cQkin`foN26RN0f2da!oEdR1!44THMa5M}Jw6xshJBZHb)1feN@;p*Ugp6Smzy=~;_GPO=wkKi#CNrgMzT z7m&o0DJ(CXm|Bj-C)%4M*r_5`q5DfDwH+`Wkl%mWh%}Mz4#h8#!}+{t&-fFBbG-l? z@UdOP#AK*jvc9LR*vT7#4FRteJqtth@-K23tjr>Ph?gfpWNdEPzKz(zbX8IPyU$Y zfnPvY>fXyRe^xucF^)0i!6RJu{&>Q^%Bv<#n>`yDk$2mJ)TN4+GG};19l~6*GB~TH z2veO}WV7V=%GO!{suyKE*e}U0c&}7CfwBW#2X$PMSi9erg}?X;7YJJ9$lpQb5VOo8O*X*u_MH^aw4k9(!UMK4HZOn1sn{<86aixkiO!TJ$ z#N8Y0T>AY3`d1O7#sP?t)yIlUvcKYTA3l1MBEn2s8bru!{&Hd4Y!}=&m(mFg*@Ma; z`O{sl8YjIF`p)O1^=`xnjyZ&nk^StW4>43>4Y<$hu0kCVo&(0t`rwo`_bC2(ci@8Z zLylm&GXVa1FqFC_HK!^SYe-?Jc3?JpJuYRJ>wO-7ReyO>vTZwQEq`$ADf8MU6=XpF z#ZjO0_(ObmLeoLKOLdumRHjv(aTenmDf!V~Cvg1YFTKQ5dlJ+`-Vc5fwL5Kawxzb6p*DE~5drnG(PSq!7Oh zBrPb&E#loXL~?ua>Ih2VB6k;EL4t>~6|17m?Iq#bGL4GE0nr$5yPNS?geGYNz$`os zVmqto1p&gNZIO6?jrJeShCZiDJ=gCPH0$-;{n(E%?-A!u_>9=qO*;>Sh;`w2F3rfc+|t6>^*MxVlBA z`1yj?lVdSCZ``Zfi*08D`O*9U=*~m|e0%(#Q3JTW9!54LuC(3GqI7tNvV{W;HPoc= z4H$OqsNdxsQp=6ddx_en_^cO>cm^@jt%tm~OBdR&1R>B1_dW%}c-Bj{1-8_omJ`HS zp|8Q5$IXv1+bOssPwa`I5)zMLn^(1?Bv>Zw9te*jP6R2?K>>;fZ=cSu%9UPXF#Nl) zX#A2*!~dN%sL>$zNYp zOJ6wC^nNz-$qD+fxPF9`VyA{p(ADvpEGRFSC@0RIj#ve4$tfuCW zW(yr+qD_SFxKM`Bv>S#3Jb`KlS->c_k+X)g{Jp|PA9#w#g1$s;&;yDa{{Q4(x&M`{r-F zUuN$v?E`}(>IV&as~95sr(xs@Xl3U^j)?!ThcwWpC}-5qYk`Nd8rw{uMh0jt_MHrq z1Z)oTe>9QU?64A9*wY9d=rG53cdot90ODrcEZ`V;30&*kF{!ASm;@N}4-7G#bhX&~ zbks1qTO-kkT@-fvp?Z)#PHoZ*-|vX67Ia}>cj>TtH$WFrWZi=$DQE-J@gvEo(Vz?Q zugCItSa8+3_#^g2@iYw@mRv`Un%=MO6r7_WTKy!>lN&qs>Qy#Gbd0_Ss0-8LWGu+} z&8|FJ+30Hb-&9dOZZtVg8G$H?=lM;>J&CF!Wza{C0W$d^CX)XyjEiSS!vrr(BbG=& z!KyTWkwfgSN$hdN_u<|gIO^)S@g%;$Ze(0>BKF+(Vo_7Z1Y5cWO>2TY0`gD>(y+8$ z7&nH9{eO&sd@(pyk^J)PPMn0pXqJeR-fQ{kg?ODa_NpP2L!tODJgl)*O~&O8Z1j^{ z0G0tIu@eWm$KVJ=Kl^4(6%{W}iT!rt2RBOMRrWvCD4>`Ka32NVgH<9Kg*LS0`m_G4 z-459mJG|--q))ngqLjxIlE^wuf1!utyqYX^P9E6g09rMAQMNn}R82dD*98-8UFt*7 zlvtR4)Sa~2)S3UxC2vAdzj4R?V9(03d$HKKMaj zrKu|IzM0`!mH(>&*Jo4&=P+^L8cE_1;LMT95k}ZCTnTa1uSRl#?z`|imODL3v~8aV zrZG$8@6y{+G!;gJ=LVio5Gs>NEI$Tz#@qrtazv!-KlJ;T+PK%=J`FeMpquEr^LY0) zOwPb4;oFXn)5zy_rJ^RXAFm`pz)oG#45i<~%hft|)^qpV=ZdnBu>M{5t{d%m^fL>- z(^3z(V-w^;C6!WY*dvm!9yH!F@T3q0J=fOJ3W45yp&KTYtyc!_rpUW4xZzzB7#ML70g}VRUikgJBg8-2(yEH=#sp^#eES+eFiO6pt76c=9Qt6@M+TjIU zxSEF}=ix|j#}Q6pd_TVj^p10$p7H}AOG&xi#;|8=u32gQtpoEbD(7L`KvPt?l>q)D zWHq?N;waMQxrV|^5%G|_*>{t>vo|H5!aR+By#WIeW0#kUSJz1Pg1m!w#DUaeb_90bh98uv_j!OX5UT` zx}tJ6%oGuh6WppPaS<$zClXB2DIO36rnAA974?L3q2zYP>iIFd9#&|bk}qb&sY5hO z-nyk@pSEAA8i*8*cvavK*TQT1b_D&KRFlVExp2a!{3MhMLM-h>EptjRBcFCLHaBTpa+rk5=lzO<$#$%HBu%bArHp} zlfu^H-ZwCSKiocv@qyx&*UI>PkLsH(?x5`PTn7h&i8w^!eu?22D_JA*Um!StM=GPB zK(@C_S>sJV>&C2Xp`ax3;=|5%8jB~ABFc`qKl-~Y(V}9GcXot<#`w6=0}pW;F%4G? zfT=vid?aa3Kcqqb&Q1M#Yv1L8c_v(CrdnlYuxPnP7-3Kl4p3;j5p4tfy9d;T#GKNw#c<_b$@t4=Jx z!NG~hYLWz<65Tt;b5?`ASw(*zrUO;-r&b+N@%pp_Kbd->nE$BW#Wkc_1Qk@(_w92K z9C4)m@~ryy7^k%l6qyUx%8o0M^&3=-;NO|3P1Pq8~( zrvV8Eubw{*mtkLYsj*uiG2JG*pOgD%dIfZsL}W8h*CX#KUbQS`$gf5Vc31=FgH;L< zt6)coW1%x(@o+s9*~xhRwNmN&R*Yx8Ne1YMnhbG*LeA^9esw_qBU1?Zlmg@a95BtX z2O0x;7KhwBrMtQK)KTQg60%XAQ7&@716rGhTwaeDeD!#-<3RSgg!{#+((;Cd$}i|? zZ8r!_1*et$B?=(?A0FJR$7drJJXeC&*#Gh$|6L7tMa#u=bk3`F@Yl6IqWpR`bMZm? zoj$AIAHiS5;dvvpXv=q&h!ANdt?;`@-S4^jR%-3!&0W89 zdMN!9I5p@cGGEXxIoFbLoKw4poOJcd;U^{f&UQVgdxDu=Y5uyt4?$i$e#M69+4~_7 z9)9=W?Qi}k5ZDGCpY-dGLO^d1C!P~o(Y=nHQhqI9UDpH&p@K$Q``w!S#m;igw%Oca z6LK|w((PjzB*$Dx`hEhbAtDvvB8DZ(-x~1ph`IYxg-0HD-|toivmr9e;}THjJ%PV- zl=4*MI4MiTQjlNz#^;?D%5LigWn%kJ2pumeK6y8Ggs<)Fr}AIq{nH91D{8)YR?I$) zwjl$2W;zONuWA>Y%_NBzbj%hdUn!G`_TwR*d%Pl3kO^g55gybxO!SauGouZv_52t@ zcbJ-6P54K@Rp}uG=)ac_qzO`femZ$e^<6%jeSTmShs0O?wh1Ku!DB5gxDVf~eMM^3 zHarVD!Nni$nR81oFqMePo^`{ByBBQF4KN-Wwz|VC0mla=zFIB1p!+qD4&z%#E%wu- zU>Sl|^5%X$jiL(E#7I9h+~Y>|Mgi_ z_`eq4>)NMt1^N)hPLa3^S;FvDO#)WY9~x41a6FFQwm)W1RIT4QXA{&4f7A(EOn;?) z`uN^Z_i;#=P(V{K$08&6wl84T?E(~k)78z^v7VU+duqAcf`5|AL`?0Pd9v`)$Nam{ zrQOqM(wF{(EQt7TE)&UAAI-XOY4-^yd3!I;KM<5n}ZCG2PNhC)rhyXed(j;x(=?8CvoYl@ynf!=kKs9@%_R z>&_x3{Jq^ker=pcB*tZ1jY)q1Byl(UEo(xHjE~|7{<96iW-|D7qT*Eb;C>-b`DU(LiN-d(k({;#`8M;w56_dAn&X-kaLAwH;NcDg z+J(IkvSs2x^UxkX>T0^6=!DMFsCbK5xby+@^llEfTRZkIg#M>G!8YEnuEHMCZ4cv{ z8$XDPi{8PFH9`IYSS;S|Dx+`Cy(-hYg98?o!oNnf@G!$#C#M>7t$xc3roDsUp8^>l z!-e_!O0Bm5oQ&Ux8KDhZPC?Q~^2kss58>b@PqgLJrIns}$xl4&z2O1hvI0R~Q7f^h zqyfCY({!)H!>BBez%cl1mZgi@Dq5Qv6iI_AyaNrHgaJUqTh)J(P&aP(F6w2BvBrGnzwdeLNAZ8#9g`*(F6!xTAb z!w>a#`?G@dbo%uYsRp?))glcM|BGZNKO+WRlkL^i^3V_!FY?Q3cYcC@Ka>UT;iHcp z_KBb~e!0HVf2_-^mM3_;0|?Q&i+YcMzI8xV6}r9hS#`Gt;y)PRQM=O}(sXzZsRTP+ zsfJn?US@&uZM|>2z=dbvtIJd74$~qaW1~0MkXdb@#4anUaoR)&hXLH3pM6~U_}qn< z?mYZY8(g0!B;s&4>dvJKa#cY!qHp$jj3%I)<lUy;$nHB%@w+rnWl2;?SE&;+(Z&WbHS;yb+%y{4$5bVo9t$^LDP!jWvAr9~0RO-u z2CcOJxJQzx7wj@lx*{_$?>dMQ#g^y4WCKcBiqb}3Pd2zR9z2|}2AYn|353+cy$aNch z`$5%9;g%1Uj4a|5+X=YKm>k#~c~`&%5o5V@!WC?s|C#VD$EDcTqrD8z;AGXhzB-r@ zI{EdCz(Uw@Borg(`TW-jy)2dCbrSk}_W zu2w5jRd`9Yz%fs_O(7jc+^jh79@NjNb2_Bo5F4e%*P_AHSiZ`Xfm~m%#9CoHw<0W3A!6 zV_5H};~Pq)l;GY950R7T)wDl%mLFwN$MKDm9k7u?jH1-r0cBuaX*@Y9k}XpEUe>0Z zv$y|j8gBmk=4G3DMsj&TOGmLeNSMZM-S2DA zLDCx(D^U;(J#KGz2M^8n0*`#xQZ(FH>$I^i`;X$u09=l4(P*_{XJeBW4hB>68ib1q z{*x?UdaStyE>O8Mk~dP~eB7}^mVK)$33EHmdDfjYR-}P(fnEM*ZQNe-tiKNc;OR6) z{;Q?-T1ATQ38k{apIzBP`0Fk1UTY@gOjCWnZ#BpaoE+D2(a|gEVZxA#C{87=_kXTL zfFeL>gRG;y?oOtc*pxa}_QXB%0oWNY)@HH`x1pyZsh+{tv@l6`3kX6Tz4nj43aLU} z!B2r_(~S!~SJ->{A3nC?3j1RAUB>Z@*AwadA zuqXw%lHEL1tZVZa&xhUHb)@$8NP;#lk6)Pgpb1<<%09f*QQ+ehJLNU_vMSS~g)d~2 zc%@jOmkuKZ!%l!KKcJXYo;j65;=b`D+x)wuIOGi4dJypH9}{((T%8}kIoD@rPm%;i zj@Q=dPvG_8Rj^Y$?`{wL`0MyPi8j6>8L!)SmTPOEtPNpz;EZkKgl&>Xwxv2e^0Iek#|k*Q30sqGSRRrWjdKbqIG*?6%u!^7uoWYi2LHLccv zb!oxJekhsyw)IbwwPbL#V_RbfEjnK*dH;sn5{u$YTrce>FN^*M*qlsQx+VsNFHdKO`{r#n^M zeN#m(!Qq6AN-bFp!78?cWCC{4>%@zCI%--nlsiwpM{QW%{@%!xpL_SK?NWR3=U~2h zo>QUE6)~qVC)7ala#LIy0+;FEec$t7bD?fpPZ~htaMn8w#>H1h! z*!@17qu~3)3J1nrB;&JF-n7xJCTJDT;}t>y4|SFQgm{={lCvtpJYiQXpxJXXvDG=< z`dR3u$*bgh+tV=DB<7IRqSgWK7EL)glJRDCXy`Yauv}$6f%Rlj^Wth0#?5SgYhvs7 z!den&Mp`PMr|M}G%-z9^8C~_9Z}$_veSTbro#XXap~O*N6k_RwnS{bpoJXREG;wtz z5wx=Va_zTw8(mpiytYZst~~b+6j-E`?kI@9vSX~iJgnOX$1KeTL%%pQor2(k0r?J` z!88|^guyMJEr-h){CwRg3H=t>}VP~3QFEQ443k=bW(FELX z*=-HB!sJhHE_o2F%`Shb!AuZvgK( zb^fm8wz2tr=%vergre`GEibg(R^b;8!NSwj6r(xmM|VO565O{uMh-YMuNuxL;j(W* z&5WIRA=}CPo6!$ibbSC!qn>e(Fh*-p^UT&(2DghWhv7y?Uk0v)oA~AORWUdaVIY6o zn$3QBU=1_`m{cmXKVOF~!D~dMi%cP2F*qeS!S(}ev^2Kj98|kcfm@+jKFX`$|Hw}z z()+!5eI(W$HDe^{8Q(09-_%b2;EmtCCf0r4F#$#w#@K}70QR;gpe_p$09Di}R7ffC zgW@6U@_n)0RiVO5yu~zxbk7!nLa$XRQi^8En`^127#dFaj*bB;H<4f}>JxmBl@zLXvRl_jj*0sMW)c(13Z*s4@%yg>ol zP;ztKNy^FuYeqD%702C9iUJw;lKM`}Zg0r7M!)f%%WFSQb;r|p7-lSv3q7{K<3lmE zkS^^%rCMU^IMN<@HSi(a?HgpEfvne|z9i{8^1Y5iQa^r*+l^{-0GFXRo74mi|5u7U$rIK_)6)pMPBHQ%+p9MCja<}QlQh=~S{tBwK$cyR>; z;I=$(fhJ6;`zat!Wq2uCTo^^+x0nVyeZ|$b_uY{^))!;$0f)Y4<=iLB%zAY;-@$E( z3C}?7I3wV5=irmaS2OpTH;DL&GwB)MpATd7bPnxJ>`q5|+Ec4Q=| zcYn{cyZRre1*Prc&eNe+c)|#!JsR?-1H`vCLTGrtN8u6v{rSG<9Il8Nn#Y6J1e%)v z_i)_y7<|KF(vJ*Vzuzk0!-k(?UK8~@4D4XCi@nMUj-ZlA88BM$PH1(s`CgA}j8G$i zz|wmns*pv1?pV4^Hjj0kha%YMkC-F6;$dyu*W_K$RV8Wt8d8$Ta;yR0NhOt9{^)jp z(Qv~v)U=*tyn6B`7BJ=I`!ONLrs|15B^kbV9+YBd>(AP?)# zeRf`@s06*nuEd2S|Mb5P)~!%Xr!`Y|sn_;s=80(DUT@sfdf$5eVXY=G@y%|{^6=p) z%VYa>8+x6m?B;4wWmnDygZw{BR06HLO(vf;sr`rJb20&&#J^7;>xOwN&MqqU-N>+6 zei3V`{qXk%LJ=7K!=0Qhh(pr&$43rB>b2q@ET_=*cJEs?$T|)BLRn6cYm%Kyyl{f7YU;ylC$0iia9Pvvj9I(@`E0Z=}a^A%JvwcP@8zY8LjG#xsrBS4)D%>TeA;F#10v2Rna`m9I>UvxAAgFmCrLQ~N4?bGtW zQW5xX>5j;}eTH?v5udXT^e40u?{&RWJ38HO^||#lXxciMi{wq8F5xBvLz{)RP~BJC zKh5H}G1Q+EY@?!89$bhU#g2$_`aDRmQbmz7Mtvlwu6GsJB^wy7lD*W=2Tr=kD#tL8 z)?W+Zlgm0IzUa1>0#y{|3%a$}BzHbQqZAF{8O8s>mAzqXl|FQwOO`6f{4m%P^f2{>Otgv}; z=?LrR2k<{oT|Je6h3A6TAo_Fr6M2Z%)VEO zT1ABboI9T)Rt2PS*-R^|sb8~ocZ%;cr&*h-Kn~pg$v>fJCpr(YE1>~ZtSoEx;n0e= z-u|8?FrEZ~PKdu&ymiSA)p*;CygO{)P(kE z97O+*8`Upo{jMpfk?>24b7oxBT23H&kZ#ZO0{# z^#F8li;8o%_3N=!=Ep7T-l_kyzLd6<+u+vgp{JCw(E4eHqg*4uhUBKFnaGWzUCqAS zjjCd^ZAW1dxFl}k+6$Fo|H4LaXMYpCw@Cs3VdV;swfYekOME$$&$#?cz%)fEkb4&J z+j?UvmgUQ_b*j$s69{nX`-^|}*1u+*WENJ(t6Iru*;!cb1P-NRqlj7;d)JA*{y?p( zwSR5AB;dd{78CIJm?LbjPG+_!j)+kQjgze-LBmzdlpKR<>p>8M@EndZ;f#&Robsmm`g<#xDic8a}_`HC7EV)H~yj?uk%~qziH2|pd}N)6=;s)$m4oW;?DLySi+u+ zgX9r?pTiF8XFk`&2=JHrqp#rWCs|XIg=>|u>WLuEOY9qFM-X-;A{|amgc*3GfdnJ( zJTd*Gt|PEo9{ zO?XwD_SN(BzjZU1!n?HfU!*?2nY?@-7y5rX9ZGX1ITARm^gt)n4ZU!oD}+g&X5@{c$uK;Tv}AsqZ$y2>&*#b&47zdYvn zR#?J+EgQ==J>0aEdK->8G)Vq_2m710LbPgNey58A@$lNT)uw|0f$!%R&Ns-zg8z#Q zuFq_Z^k03Y<}Ub_L*6`yK%Uh&tUZ=Sye41C+#%#DXnc~9FoF#-Bk4A9aJZm!h^aA2 zKAsT;1N_Uk>;maYPt=otSPz>f{OFE)UAPph55S_h+(6;u`{#1XgsXYRV?*eda4)4# zfOJ^7=B4}qw8W&i(tYdQyUt!e08r+8Usm&5@+#fz-BdiqM8BN*vu9?#NlRO`;&)z#Dw zrGJwFKuZxP{|u2Hh-cQ}S-Zd%up6_q)nol6bUpN9IAd{?WBRcxg#|7-SGY-2yLx6U zIUdN6l50P6&fCwStlc$?-Te6+i%&tki)VB1RIgHBI_ z0QMeHHTXjZaL%(D3Dwu8y9+99yY=3ii_%C1t#>e#Zy2_>8y-^?sx10m{rT*n%j57n zA{M$!Y->#7#W!oLTlyeBUlIge5)>bUd@IGLRIX{Jnt7IPsPV%u4|qO4LN)Lf1bzIr z;SF6Ha);(m*IJEisX55?BSGxm3thJq!qYb`F}k#7R4isok=(zUVXdaH9t+8Rv0f)7 zPv##*Z9&TQq2l=|=8#OBT}uu`jFE&G4arJ5^rjb@T~xwAbj%q5_rtvf7)QXH40`*H zzh}1>EIqQY^G0N%TW|}|fid@WD!5gZJEhvSg`e6t z>2;jndfo1Q3mtG8_aT+t7+D?mU@qPs4EfT$mcw#hSpU>SZnr;n<_edUO(R?6lD-FY zVJSDA23|Y!7!M3F49d7xq`cT|@}^PK@)<}CtyXFv1h1^Z=$1Fm%FG(S3S%i!mJ|+? zSS?~xo>*?RgGa~#pc?qNEGV0+G|J+!hK3g(YMa1jE@KSeeXvb~$X8>N2!MW~llUTW z83s1}ZO~=ugQ(}~#Toy99H$?)iVqo$rT>fMJ6I{$Ks~ri!@hX`O%kpi8biV!CRYk? z9nC7KBet8llZsYN2A?U1yC<4PBUqT%FWzI5r8laGC zzC9k|I@spq3FB&2Ha1bnQZSc|@|^=4<@qPjGjhek4du}Rb*YxBJm{+!1m1KymlgPm zAx3lCLka`n6g!~m#Ehwc1-+V>7Zxf71{TT>H)VQw0)k$9jKF)8QYQ}OXEL9Fh&f_< z(|FAhkT=qK{QgOp$jW5bzf>4G6I5(Yh86o!rD32(^=)`~B%vAL^7xA{`E}vk++kH71bi|(!1Ls>y z6{5A*UBh8sZ#U2#&@+AW`L%w)01Sz^xvw2u>IVYZR~*kJ->JNzEU5L6dC8as?7RTsg0j;b#u^oFB&8y!taSVPIwmXI zTa*!>SSzUQPdEZ?S!EL13XUwYajf*aoPWUwz8@JfYG3b!p(@JV{$;bc1%Wma9PD4p z6^rpA*ov(mD-ptg>2IIL|3%oYXgEo#Rz=qdF)$1~#WaA>b17k7q(w5W)=ur$v-PGh z@@3p%9iBv+I5HNinLbLWsawvMft?dVC8AR9r2{;7H2NB1UmV5rv7ZICBPp7uP)DTV z@T>nh@`w<-JK7jAp}(Ix4T*cy%pNWvsCag$Dnj=mz>?~B^3VPc3ilrP zgQ#LN-#^$*@7H$XF9x^8+b6A52S3&~3xyOhOmz`(5Xoq|n|Ms~QEYtZG2Qhb%JZ6i zqP8J3RVJ6J~ zsz_Hl(jBXij57`oJPZb7!7Ki6B&um7Dtj+*IJM!A`J*LGy+m&@jNM1WpNw-9P>PYLt@s2BYMDkG!yGO&t2L6gsabf#tijjZp?*ucHjbV$_K%{CWRqVcYzg6g$F&(V(d>3M$_+UkUz<4u}T96N{(*eAmwTK>O%$J(RDc$oGVMP~3qaz@O%7insua2;ymmHDwCQq`C_A5iZ?iNJ4H zC~-4%mrAG8$5H^JyU0+Z5o=X=HcpQ3(UZZltY97kS|ggFLBF@G_g+u-zB1N;2N+ru zvt-~DLWxawoDbXTwzt!#LJXbJ(;qTo4M{_@{4o$%Ba&uYdK%EGaQpdxOr3R96mQ)2 zm+q7jq!AD$qy=Ot1*A&^#05nu2|+rRMnGCbK)Smdc4bKdjL zKRbJNX6`eybG~!WeP7q-+7J9Bn54ZI2)w>6zzYtPHg{uXtl5)S$O01q30C+p)i){e zvro_ARk*#2)yUm9GNX;o)a>`WaGi$3?(9ix)W=+dIWG>`t_6?aUNaG}#o87M2A%CY z@r%)ZL~V-T8vn)kjN;lV$3LG-0Y}%*J!>I(f2JRL$)ckpqd8$-isf?E`&yB*CLdoy z`WlNlkto9VK74AS=#Q$@1JT^bI?F?P;mht0dhf*?hS@~v#+#saUzicbA#hB8H4GYI~JMFSw+o|zmpibmuZ$_ zBu1bAEC|u0-XlN~X^?)(6H5+tQkIFBO3H1?AZ;1lz{MSLJL1Vn6w5Kef++cdHc~#| zhHG5q7qtJzKk;C+*C*v~d1ENO{DTn+!M?jb2_kdeI5e}Q%N}R=I~E2$5QZ1>JEe4Y zGL4xeDucjyxJyyKENdNFK~<1_k0-3 z81^#lq|u6U-U>KRu)q3{*RDWiQC#i1m8xgv)JhZgw^8NhF=LM;isN?isH#i47?1%sBY<+B)+WEi_;oqXL37Ror*~(Oe_))u;%m>-d&xmmmEZkq z6@%`3;dZU>?^(MzUd(2rL|uN}9VPt-et#i>{UL@jzwU(Z2(n#-|M)cf?=5S<+gbE0 zI&Byu6pkvLyEofIibMCAno(5x3I2RokqXX6sz=`Oh9tGxnVz70h5GQM_KF9`BFCUh zcoh#Q56YnC#9~bn{U}v7`u97ZJbyES2++cQ-XAjY*b&vd%%Azlxuqgsi*X(t!dxqRdoUduhQObc4xU##91 z+?K#_&&(*^as1hWnDHf&YDQ+9bAJ8SaRdZ`)71I&5QMR(i%?{G%pKi*fAnc*Podhc zJ6J0q^mKM{z9Nwro1pT$)4o*WCFH{SXe%a}iN6W`1Drgcch4>{lw7%}=OSqY3u}^kGXkM z&;`zDqJ#0ExTIk)lprn6dl0DG%zJp$(fsk&y{}TjiDPO@(szd(4%Qi;*5TljvFC-e zY_16RM6MyfAxcAnbD#H9G88*}I$xV$r z7%%@endGb0SeQv6v(u{tQm(5Gv_E~@Z+SeD#s>f!kG)1-5Tx%)wIO!pAA5QSUCeFW z3#^~aRFRPXFv3CD_-(>&O7~l4tlMVQ`uJ&}34}Z^A28=>6WPx$W|& zeb}F4{HnL^s=Hbch;4ow#bvBn=-)O7m0AEa>~}(Nju|Me%}1SElW*&^ zocl3Z$+y^HL&kSH>rP#i85h)Y^LIK(Uf(@gw9zZ~y(%$4)uk>*PhZ6Kt8gq3ytVn& zu(@1`HgSiyNjPuQhXxSSe2EY{iag4^y_pRsRudZC;SZL-z_=^Ia6M$nni^MbJaXUcHuDLqxTsOMLpk zhx0|GF4*{2z|Ds2F3WqOP25H(>%ljy7Xiv2NXDK{1#m(Fbbq^V2FTqxl*fMNOyWpP+Wh#@ zIGXEe?J}ZWFtB^jw5qNRhxT41HzGwORy$w8LNWC(iu&JUTkL*}<98xNH)K7tna_yJ z^bTBczAvUCrGYEb@5yNJMb|>B9_PTTeDU`z-y!?%O-Z}5kDergPb|--GCy2Qk!Cn7 zzI1`FY@xzrPE*1#RdMJ-SJZ?ul4J^I<+LXL< zi@ewvMmPfjfWW+R5gm7;hGZ6}QB{ zn~z51B?$T4jfYYUci(V^?v*X`9N8|27-2^%D~r>|M^%JIK0rOWazF9vtWI&%;KS;# zaSds&p8Sezu$1PSQ&nms41E!bf`{L9Rr|dK*~pkeqiTaD{!-{K$2Ap5kh4jgfX3P- zXEn*ay5ew(WU*PV3BEe6m4(>HbSALLWS+!Y%>UdiW(AC18LzayxG4sv3_hXs&+l$( zg-t~c-vHf>56FkF{SCaBS$SzczSz@ty*jY|*H=K2okrW4{hP~Rvgsr_}qm-g#T&n7607e25kaF=D) z;mWiLo`sNnBC1JnRJMG-m5^=p#zlA%#B=sK8tW<=aQ);0Un06LQJ(wU zPRNySTG=DvU77Qh=vI6l8uGW@59etLjHltS=mrU~Q$OeDmSHa>RcemvK|D8Yv-8Eo zcCJ#kl>x=9z5|Ez4`m!uBWu~JWacdUmpm^;j7&A1eZ~dfCg$;zZ6D0=fc^Q zb4Uh1YL-X4YBv|(2Y}m}xgJmMg1aXm`VVIMl|_V+iMGm>6;(Mq%Q614_kFW4d-vk0zl20A3u z6u?Cr{g6usb+pF{apymKPBTy&oqD1n$uE6;F<&H~{U7bf@W)C?*Khtw_I6wVy^)T` zZMrjWW3q)4V9ZknTiV2Z@(aI$V#};KCc3pBQryGEMgsCyA1}_Kz^c9DnG=FbpyclC zUOz*Lqr@>oLvj4OAy)An9SZjNyqc1A^~UqZet>%9Jd$<-Lk&Ny%qe}cjBCbZeW#6r zw2MC{-}UnA^sTN9ciGB@oD>t@ z7>>L0TZvrdjD(Q}(*1v{w*jT{SIQY(dJ=lSnSvx3vW|e@^5-zm>dIr&FzA&!X7-8& z`k$MXT;RC%3iSH#ecS{Xt8_1l_8G^vd15Tem?_k~jV`?`UXE)7yaRS08`;JeM=mFW7>N1rl_}N_DsWZrliDiKcg0^o1NC$Q3F(fz z-m;woZ$6!OErl6$fK)HP1U{$KyS#=U+pQx$Vp;piWxe+Ib|V1d~0rdF*{#Ro5)Lynb6^{w_9uXT|cG2QW1gphA=AQi(s zUN?)v(RRh0F$2@FP$+f$mvARgJcvDw#k-GDzOjqy+pFp1L-mgP1)!EE`C_fVgk&6# zXX5DVu+>DIjK@3I+=6sniGo_$xbnH~sW5GKt^DxI&J5vzwdmslDHXZ)S2@E1T=O(&LA8Y4 zR4YN`Tw#=433nVsj7{uAKun6>qNCj|eLx5qp67-o2dTA;(ob^y_7Wu)LRu~_rQX{e z9)5?cSbRY#1?e+@^tY`>!Pz0%u~%AqZI@IY7S$6zL8n#TT=;Pixm4QO^ z%CUviiM3L5a0TSi8sGv)4xw;#<-5m`0?p9t)eelz37pKI`^`Qv$Y3_iO@!D?64QzR zQ~JB~KqyYfUa&|L>U=3-bHf7q%-r3qL6R(kFB`Y?2qB7>$9Mi#F%w#5OZTJCdRwnl*8}OV?2C82 zO2rw2k<-n5k^^U=HTOnG<3dybZMy(CD`GOsG*&d|gMF>y6F zc9XIwb4wWsgwXy^-YVOF=PeF}nd6h%Y6cTZ37 zUtrC5snm;e|wKTy@l%Ra)8bMI*zEWB|A;{oZ ziRj)jE>L1sG(zuUAUhq_KC(j!dgRZHraG#LMy)*K7I8? zE-Pa4%Dx*`gN`HrDq11$Ai7Kdmx$MvW`Bi$M1s)f{paS#u*{T2i z;~`wjO$@cF zB>Y$qkI1mRCUTU}2ZQF{KE-nc{O!AJw;rQxumkBC#S6uk9u41U|5Cv@Z*DD*~ zwf$B8^;7;8UBdgM#LAEC_zCqWE1ywe_=LC;kndzOb4t*-F0sJC3e&0_2u|Nq^cR!K zxwR#LjkZR*&>-M$Q}U#X{i4m-%X^*_ zS3pWzqjn4{P>f$vGInme{OtVoURrKR029l>Q@ADJ>H+r)<<@CR87r7c zmzjPR6@xVBBYnT^pDd%Vcr^Q%c{GpO;(IPH?504R_hWx1utR}4=&{+90Cws8$n3>% zzB%s%OWz`p*kze#GiyeScg6$ip}@aRX}wE7+xzr9T;Ol$C5`5OSF<1ZK?=3L4roxAnT6gq!F!K1yYQq#-uWlkmN1g9d0K2z!dy?okcpE+qzjIDtOu zVIqI&;a<|jlfrLq!546EhU~{D2hvn4y%gE^(W*3GKG5@^mZ>IAB;KC}#Tw2$t+AC* zsYBvIC*_XQ*UPjXHY1PeZ^gSk5-+$q6}n7z8S_MFqYH)F^cX6b?QQ4LLII8*PNoK_ z_Ni<>fp-nUA=WT|qw+J<6!if$MG3yn(c(;zv*!nES&WT6Pnr z9CntbJ5KDpY7>3bJ)k*mC0QO(5~9M%VxHc-)Ivt^Q4Ei_t7>^nY+RkT|2iDqI}2j_Q!u#DfxplnpxNOq zKGYjt@}ZeIs<^Cr;yE{1Yg?E3Kru$N!3o?YO%GanDF!fxx1I!ZNE z`FPBjj^fU*TLG1>+#U>toGhqKi_bdn}io``9hBSLGwwNlzkB+I_^=kaKJIxHaX;>=rLv;|I!?K$L3zDCOp=j1`B(%4< z&==wyidm}h&WTdZf-a&pNC_!M=Kc5V3zgZvnufZvrYrQd97N5}^}cy8<q3f7|8sS+~mCh^tUZ^uQ*#85y91&!Lcs_!HeexdH4bMWtMItCJ0<2 zNJ9@_ey&UHoCfR|9BeIYpzYC#(n=CsGTw$}j}qa!5{-BUjd8R$m2RdFP#rb}916vysUSP|rTr`uzy+d;}5G&_h z_~QB94zi@Vbb5vlg)WMJiz70^vRqYPx;_nsMu3C>_egPeZCXslSG(MAk&XosASTG0 z*yaU1<=K;pY@#1UXecnpJ{09VtgmL}OoLy1{Qk2WYCDDv-3qtZHv&96ah@9pC~geg z@Qb>(=pV?IC`3IO^HH5!3_E}*Bk?ks3!IZrhkMtqyqu?;ER zZLxr>_4@n>3B1$4td2{MM?nxBX0t>v&Rs&?sQ7FnT9N zweLtd3SOiu>uSm90%t<@T;Sqt-$NPlsKqcMy%P}#bSe4VY|^wwLdqLx>H&TR2p~`z zJB15449)HWi*g<*1lswC@k?s&#QIaSpL?T!bN#kKdjUVYPU^Dr>>my@Y694fs ziJ$S6m-X*oYGdph!lnBDr0R#5%T5@aTf@U_t-9OYyJ{XlDTSCuz%9`i3pOTMGQ!%y z6l`kn4=c_`*+L3>?sq3)Xg@D6*Ezgg!%N?_*Yk5;mid2=WQgxQ4Y@)Gr50Zom3`fp z{S2bRPJ1CFk#TLr4OOm?YnNMz_VsUMim_W{gvrnSBHPP2SbfKa2VxbUM41*uiw{B) z$whqnk&-I4agH@SlyC>pfUkrf{TO2>-67OCNOuGy*UglD@wmvTt=S({Su{a_NqxE? ziS8?_^-89%O2Sf12+OYIBa31p>Y(2ESDz))!#MZAc~hf)%Wa~1LqUq*PDr7 zztQy%ALZ|6SXH{J?D)0(DRIY($Hm!q>)bQVF3=%!X33aD%qj(k+$I+NJXA>KgiE^4 zI#yL$pu4f8QKpsPHVca?Th%Ny#vv|mp(1(yqUJcvnNSYx51{2)tvuf$Ix|s!-nB|{ zqRvTOCd2s$p{JftlJ3l|ur{~_+?>lQYC6i4IV*kg$i$vR*mHg!Z*F;-CfOIs$LE8$&IlCR9Ct9bBuBK}6qHb|G-NQgOSAwKRyYm6x}(jV&r?U= zeIyv_(7P%OO%fu3I5-g(yn$st0Y?HwU;1+M#7I%EJQZu`Id%@Ne)zkoR=4T|zH)is>(53SjoOh}rm8=j=^7QY&J!9j~zZ|cz z>fgVWdpMKKT;d-N3@vEJ9GfX;ih=7@qg-ev( zziK6@K-=59_bG_ibykwkZp1jgGf~x&SIZNc4a5cnb2mYy>KlrNAk=H)vdPYlpb31J zS>6i$^9S9Ps(V3tH#+QNZkg=0cup=8>Rw$tTXi-T$aCV>XcM&*_=TbNZb19F(KGQ| zg>u#iBY2U%n(bu-`}Y2@kq0t+H|4%>>%9p86?~hCK%ua@`c{=!n=df1`im6a(cZ!b zGIPlk@Ac;Rwnur8C%gN;ub*8J&e_h6M8muth`c#9i|)>U^}Uig!nX_)V0D6gBIx<` z)!D<{)kN}zkw?%W-DcG_{{G{=KX8PT*u{sk0A~hzvS;CHVOaq7_`6LFMlnDMdK+C< z4FOoOqWj|CI~-P^_EFyB@seNbm8lD-6$d+cf0I0luc*FWl{bNys2- z(=QDqc!k^*A8i%9PC);j`7gX6we{@1y4RJ%5%%!lP$)3%!&5_?O2ipE70rZC*7rB) zF)J8eXji0}z?6BeXBS7^^iOg>{JJTYiwLv&tP{=?`Np7cVx-AZ7oCDqh9HBuLkhb8bSD*HBj-8^!(_Rt*~3 zMT}giRB~M|BXS;e>r6=m-7S-O-^yqCquKo7gEsG5?C(r@Dz6=Kn7rrZkPXE4V?b#x z0Jd9*-#-Gj?AfdNHv^#rWg=0Ok7ToScE4Q(zi(!LNX!TPKEqH-J2(=wEy%rm7yvPE zXJ=I^NvCr5Y2HJ^rWO66Wt^)Q-%^^E`;)@k0|(hBXr|Z8$M)PqO9O7@0zrlU2_tVK z-PEK7mB}To3ajR!K_jq3tY0_ky>#x`2p@cY>~NZG?ykQ2F&)+!_T2Vy}dB6GE2scU1e}|P@+?>jrIO?Lk zYwd+SqsxlI)OYXeQijEI2+QjU)botf6E5cW-f*cW#7F5R~c3g<7 zPk~ny_PY|Da|R#MG{IiS99fM9DZsaN&HN|O3)EYZ@V|^t+EavgWNy@hZPg*{%k?f; z<|C8sjE@Lh6X(#a)x$>i5ox4KueR0WygRZ7+D_UbcW%&u#l#!I{8bVG9TQ!`G6F}p z4w9_YW(U$PKWtm4cjdcn+B-iM201<2EAW=#vqcU&}$0)en{0UIMjH zT{QNg3zxGcpz#JR`l3plH(#G$xv9Y+&UG|ImBD&pemrd z4tAX$2s1NSpCHUMw3G|zK}-7l;OA@n#veR_-z7U5E*VnwhDlbHWN9|bPLIEK1pjNr z_1o&N!*1E^tUOx!i){taeb?M{@|1ei@$yrn8gRTjK>!X?ZF2;lIi*Wi+4TEogT`9V@&s zH|)o;xuIo{=whsq>9v!pWSZmbKdejDa)=^;@IqovTzU`d%&Eynv6T4ab-Zmth!yM!b(6u|{>!GcUQI%vYr5MQj z1Bf7q^{7Dz`W2k>1rQ}PfzF~Y4Lk-XNbff$81mkTaRIvE-GKF7&HXg8`>O`N?+Xg- zjsP@7E-5*C2U=uaFoX&`SI#Rv3j|6ghi$;O_TPDxAE$8~3WAdQGkY`KTCrhtgF@fq z(YE~RnRqxPC}Tenl}J7>crjyCypsAT<{7vhOK=Bn1P3}pQtQ-f5jUByb|=O2Hn+bp z{Clj86Qv$@GMIOida-zcUYub?EzDtjmjQpP@&9l|#y&3&dH}6~xC-kGq ze^NhBhmz-m^JR{p%&WMv+F(=H%s|eY>Y~JwW=ys{x!@P2SSl{n0QQJR21gksUuiRC zy9oDc^B}s|;fAwBcEF3bPt~B(o5w-;L&PHZ@O;_ZL(Eg-LGTZ&e9|#G=X^t`AT}_8 z+(9+vY=UUiTy=ftmSl%?m2 z-I$(5re9+!7xl>xlra9+ztHyVO)%ZF8&Wcf7C6GS@K06tQHhc4Tma-GrU_O75BRKd zFV`@qIjUI$n-Z}LE4qZ9gnEFtQBJv-1DeA@?j8F(I@isTK?`o*`MtebzRWk~-*iV; zu0<8T`$bQ)*MxmX@M|>4lE02a`drJj&4E{mWm+!F%EDxMJr~|ywvW8UrZc(w-G829 zT{JGV3x&yMNZUEyY1%IN*)F$}K^{ekM50jJ<&!QLl5LuiwjZ28HHLEX?r+f@T%O4T zQS{n_M_y`*Ny0I#JTB91I-OsdM*SZqNC~eJ%+UoBC0V@bxj0O?5*HIuQvCd2_ z{3H6crp~>t7~fIblccMzHx5pwQy{~8Ly3Mq7*hW0i^DULgTj1~$4h|MBcc29pULh{ z8=b_Pn}SnQs8g^KcxjEO`Pa*x9~I5*llgTJ5`F6m2RTrI8L`l@0Y*|Pi5jUFL42@E zPNA>4owa1Iot6rudqr+2vFPaEo-!QxPu=$2-0Gh0vQdy&c+Lf4So{tH%zD02*;Byu zn7w2D1MAY4M6H;{pN@~NuKAdM2akP{$-4c+8-7M*^8qi{xgr=vh>I*uXwat(JTQ{_ zvmA-17yfW$prs#Tga3*3d7+@y?b8cY(eM7pMbgI@^N-EfLHqbpwVZQkM(rYC$ z3AI#9dF4B)WFf?6Vjq(fPm-)j>jhs3b8cC@gNszmVbkQ78V0P*i{haD&*>gAOWk}Q zxW8+Fh1gSb@rJ8y37dy60`WH()|bC~WzM3i!f*9JPQDo;R4;z`XC*1sMy&BXi=~}a z<_I$3gZ?ex7_1Q;#HfeH{*6X`{@^-QcKsO0UGaKCb!ie>8<~cD)@PD-6PUQ4;Pz)& z=M42_&$N2UDlVeG`7EI)E~;1e%*LsFW&8=vIe3lkcc+a`0UYbIk_v}c;_vO7PbddA zR?r5Ak?B91*6sAxmn8G8`Y8S6N-!paNhi%PueuM8ug{`LOi1Uivc|lULEG~GoWE<# zw_o;Q#5BR7wC!Da#MFGp9R`@JBYSkXGMA{|CJeSuJHF8jj4oB*pyAl8!TW8+#qGe& zU+8rV9A{VSc7K|sj)=KTIUjmWu=_V&6~G%XbNT~cc&@-0pItHpz2kBnR9tQ90)5$1 zIE?^KRL`Z)rI+#KIO`XROZFL3^y=i54B0hBV*DOBOX&FpOq84#JjVD7%Ja`_WFIWz zmkL+Ce22r@v;eTsFc5k07Z53a&QAp%)9w4xN|8)`G!Uu|nfmc%Z-QTi5jI>+gH}JG zgYWtqIIORXbzhFZMTtUzjEKza`?g?W4u#&6_h#NXHO9e?A|W%-^1Klh(NT*N;-Ji&KHy?ehCPA-J^bW5;I-o_#1#(P6}s(Vbt%Ex`- zz9{jK)>Glm1R;LE%ciLtoUrFWA=v%sQv7o#Y>;`i$U)Dgve7NokY?i&K~=kG7ecdr z8co&5j1ejKI@(krd-7WDB^9_HkJh8sBKDAve_MjFVNw;t%MQbxA2QD#O@w@{2a4*z zgR0~Abd=AZO4Y%a+Oib!W}obX(PEIB8f@pyE;YTNkh^@g^<4KRu_!vqDFd;Ynte)x z`{cQV83R|$soCe@#P{7S=7iM|YLI@}@J}b*S^2*rY-40P4;gz7*q2VLa!~~j4lIkm z9p85jS$i!Rf#; zApu)HmWLN{lwpNuS2q#CBH14UCp^qChp)jM?v$~twp?T^th8k2=eRTbxbNj24Jlp^ zOdb*gui=lKDs)zw4KS^a?&tkbRhSq4%7uq3+n~(>bVlxpmh)5#VYIM!oXfl26aF-MkYUgzqlk;lq!=Sj|5={Kq~=;!X4RaDcIN z&I)sO86k@Fx~#f9ws{pA(7j_6({uk}VrR2r+CNMfmsLmIL!f7~;^7G}c1Q9bMF0vx zB^Jt4%Qg@qnvuFnf=HO|6q^*8exe5 z82=SMD?T_CK;ZJe?F^A1FA~lDvGFS^hXHJQSflr*^z8xf<^3z-x8K!A{Y4?MX#?hs zwoj`kZqZjhD}*Rh$k7n^eOPQo3zhTI;T)cHx&ObClx;jVD20YFv1>_Z^-u3p5y0^M zfcy@ymH<$`ACa%toCntT10va~G5f+$Z2q7=CveE6G{nmF+vn7Otu?NOSuQe~O8x@@ zvL>t7{S|>e(|%F)3_x(XE1i8*N&a!+pB{2V1mOj_Chwl8G^K8ek8YUz4gst%f3WvZ zTZz&l-|usBE~$zl`m?_)y0cdP@!EpFaLsZlD!bnS(*j(X-;PmBB+_Jsw?y%NPY{7G z6ymN!$Hjh>GJN|oJLuOl-QQE#{cIak?jy5M66jgFX|T5p#2sHOlLDuKV*_-Lz9>!D z5JleSAQvq!`b`CA(x+?o+qQyJ9dDi!6TO?xe;rylCra&om!sY)9K;p*wp52oL+R${ z_52@@U-V1g7byINtJZ)Zd8zNBH$DssCzEZ};yI-UuZ3Wy7c+?_@TB7O8T3`DjOHlak-j7DBEtmQ<(cw#@u?7?LLZm9W;KH}ZHH-VL zyFWRjLnp<_Yd(GvtPKwqKv9A7I+ZMGEI)RPqG&<0nx_k#!5(B@H%EaaJd-p|xFk)q z3#xgH`2F6Y&nNUR6dSh5=u2pfD*K7;^P&HEAsUE_RpI6M z4Fl5AXVk|+;|Ek6S7p0?EOTa#<{KVXZ%?c)kp~>d;ObYWxN!DI544~@U`?@~KZW>% zBXQmICRhWsA1>a17mZ%vYbR<1a^(GAK(Lhyw!U-L8|U*ZNxijsv;rz>$-G%lz3v(l zOylvi{*^t{D>t>B|Fj$gJH=o2*fC1}Ci7l5%xk;Ox{sV&KSCtv6ighb-m9aGj% zpo2)fQ8`J@Ob7xP-inyHu7L`)D^Y zK=Js&+uoEd0tKz4e|2O`>@ez2dJxXZXlxatt(E#IE^8UUo&7LUVXi>#@#mgW(wAm; zZ*f&O^($NaJT?;Og|Y++;&!{g3Q-W)esGRF)9y0gFJZj_U(RZYsOSJAj!80f&>+oe z6ocaOH4q>%j8aNi|1H^t86Hh`Bl9!Yf#LRO+Kp_%0SbXLo7ETMXKp#1v+6^|&IF45 zwiL*7GYj?PE2@7|Q})s*6Swd8;^8=zL;g1PLQ0D-gF2|q*M+YUm^KT{pSks3%i!3iK3bukyp1Og zWt^NBf9gdo`QM}xv$L14Z375%wM6+k8|V>Gg`eUwdS&K!{j`vF;3t<@E_YtA{WS2f zCMz(02oIC#ourO~#ch6DKgvFZ%S?0U;^}%e9cCdEHUU|Rqn+4l8R3)NJ1%XJ1IYIA zoZL;kj4_}o8v`dF9+=}RDV;y4R zGqUAIp!yOxnS+To;9%9b^y8O4=GLdag_#LG)RHPQl_KYp+&66GLA|lMw%?TQ?8go2 zUKDQ+b^Z%)>_wSw0PhC=;x#;b3^J&o_@$AQekyNIoq*r}D`jMFZd+UZG6z-=tDY3z zDTBKCz8UG%AW3DBIIPIlXtq>6KA4d}QJ( zW99C$Qo3pG)}ShQV^+~+(?_j@%iei@`4k>%*a#U zDQ9u>1DDy)ev^mU`|iB250p=Rog#j4lYSAARFM^iN}lAT%OU=ko^5T$s2u=*c%3ngbT0`X*V5X)oLHf#H1+v&4}Q z65EyIQ`Np8)51=nF%f_Qzv){o>{y0t9#yg5k}JA?&-9p=jiYxT93i%LsH^-q+)^)S z%-{JY`WJBOs(!PZ6bD6a@2xy8Ivs0e5BxRE_Um4c`8CI4Okut7lg_td^ZnTSGWA?R+XX@=W=|t8-V+=K^Czhogb; zg|{M%D8%I7!E+G6SmF!-f!Dwb?cyitN-fl7PZH^itnyC5NIr2iGqsCb0eJ`84L4=v zF84fVA$a(#19Drfyt!ip8ZwDk%o4aH89~w0@dMguC<0H>;Ma3 z$bCWots?L)($BtKB}|PJ6S!l@-UV)L8gIN@xlO3BA*PIUB?Wfoue~?d`lK1}?DE04 zGAzmXzuju}ku)otiVbd9>GgLp7_$pp`J#L!3OMyGka$c<`aD)r6%0TAAad;-2qOgB zD7#35M^DWJdY!;4Ix@P

|K?4;KK(9Z4o5zza$t83eG)YZ)UrXm}H)B4`n0Niw? zrWZ#XumtT7{~qdq^F_;V6s|3jHM^5!_|0WTEv!QD6c zY~KiqVGtqY&sQK`H1Pd;Bs^C~s8U=$pIYyRYYNc9nqU9`!aV5z6t4OUT zZ`7%8Qzri-*9kt49-S3?Nn(o4^?{hr0a9aWL33=qjPKc9gFofJE=|N)e-Qg2-@HS$ zQ*xv|fYT3np!N0qhZ+O-S9;>I3RIrQ+@t9>&=W`PY)+CdlMX*oHPQ~ZbH#fbjFVc4 z7_3v2l}@-8$#}k07+d^_lomTT;w+ab2Qaf-5WP0MXl|Oo+6vuPi3S$Ifm7qx9)?)G*TFnqtHh@SVJMvXS@~F1QzUIwH_$3@)~LkvR~( zYyaQ?vbn9EX>dwU-uOlFK&VI1?yJiHZLrLq2%+6Ir9zy3#HWl z_P_X9lK~=4H-iLgvBolTn1%-Rpvem+jFyy9t(c`xpJKZN9OUa}Ci4s(C`b9w*yVt- zNtAICg&y1G?vOz38|jsx(l`M?~1kno1rS-!rx=Eo#|EP z*22|d8C1fU2%A&KbB;f)i%8QJJ-X1LFrRWD*r7zz1TDsVm^lJ0A2$wng&exvSiJ{z z)v#x2)kkQD=l1>5N^KoazAn)#~HHuZBB)|e|; zvJZHm1NT`shGxcHob-zaUZAzn@;u7{>Ridi}M`9-KX%9h&nD_2Pz@S&eW$Z62iBR@hc$-u70xlfY48UE!-Pd1c;t?=t+leI%g0 z$e%|3{YO{=J*gPn)c1%mWz-{}hYikp_-<;|v*f{qOWuMvot;Hq`7Ufl>EcDJXbSR1 z+3YfYyS!4Y&zwB7wf&7~2HQv}ZW#PwXRO3S zl!+}Z``3D|l(seZ=?uOU-B&3v^0#4Ht8{hfa{RpwpZ@rrhbd^pp;U2vp-RTo8(G** z;`EzmW)R3FqdX^CuWiUoyNMeh>(w`XbQ9&(P42icEG+WV+)rVaNX_EZ2xoBovJ7|2 z&8V+SlR+8HKI)J`BefStir&_%H5ju-ehiQ7lIHjH^YbTxS==-)hIY64<25b2R0v>c zzHgxHB-IMZYf94Z`ETn%M}JD|@o%r*NRATk9;Y8!W`k`D9~&1u6z+q+eJBEKvhsL- zDC*<9Th9Nfd!a(1IUHiNRdh_FI8=3`-Lqu!(SQSdU5k({h z1f&I#Mi@{*x|D_)Izsq`MiqVWXWx&7{5Nl=7y|K?(D{Qc_Ao2zp50Eq}?nR}M%Dn@zxqG=+` zbnV!CO=NcTnzkg%p~GR$=NQ}3^DFCcTZAqIu{y_M?F$coOvrWYoF7`dYy?bZe9;z; z)@73a2{pTpX~v)4u$_m=mpM`SgTEGGQ&|4`hpAi)ayo{mQ1YF`eVsYE5pwKWo;wtK zdj7(Zc=&TSC7%_@v@!}ys8|c%sV$n3H@6cl{S*Eb(`d+-V8vS#zjC}xbX4I1$4m(g z`4}`K0OFv>H_%S)$dnxeo=&S*1q5W-{btW*WG9wxgz+5z2p>C()&wrECsS@tn}zUB zM%n)zX9qu8Z(e2gEd8y2XXEmyAXYg@ajhC0C z=GY5p6mP0s?6=9OkUpR8-rWp+ZX13S@oYC8rdsL!WwN>LyJ=H46nBqi_|^fAJ7j@Bgm?s zZV~#BpJOl6A7xi-b~Mm$<{6}!YM?C(teTnrymZWeqb7`pEBc>QJE~gecT#iAjebNm zg-%CM?jVE{-)K9@8;`C0HDgmr;!@98pUQt-ML8C-uj+6op==``cq|4Sz+*=7l(Swq zv46Ys4H&$kHT@H#u=-tRMbSC-GTyOcx!ajj4qHF$oDs5K0phI4p4W5+Vf?k3$X&rOjUiZ(Z{D@?+fOa=;n?9M}w<52JU-X}D z;ahvoFKliI1*&t4PJ3(dt7CG0qtMwy-~#_~=HbX*?7+nyec`%R60^5bxx+E#fmGmV zc(6*u6Xwip9V2m1HVa;dnD@3%+WH~3mM&cx4}9IouG!VF^yB)>eh=5!vp;jzED+`Y z{-y)Drd9Uq(wQU`&DbpBMQ5-AoC2nyB1NBcC>__H%TP78?|^EXk+I)>O$za~njPd4 zd!Q0Jt{`g!bT|_2lJF-^)8~sA32x$3*~fn{(g3CYr4NQv2jVA3mMUm;!pfBm*{ug2 zUjf7NgKQ;qn3xTQ|qMm0bmZR&t;yYiU5FqKHkCulvh59px9g+iel=dAC)c z5h9eYv8II51$&pz+=BO>E0}b#-7x6*L-+&?^mDe7H13(#W_{NU)yzeNH!N?KC}s1- zI?-sfP!|G5F>rfe)=xX{8WxP)IwK>SrHTI$$zeEvw;Sv6kgpx1&Rf{y*aGMw%iqiI zWE|L^g<8BYljD+Mv}O>tD|}o z?y7o>{p*&DQe`->3C1ROLhMKqB(8o}JZ2nDwxB33F~J|LlsWEVfPYP$ly9CEb*peN$Xd!aZ2Y7DDnVmUMYh zS88_zoqOi@;`kn!Y`Jo_p@Y<)0eB}fZtQ&lliT&;spA1>iGU`Ee};sEJ%iFe6gx-S zb*q8~gL1Y$QqQrjkhUSw_F!Xw!ZEr(K?R+S>3-JHAe~kVL@>SehfN~TKXQaQMaTc_ zNzxZeH-a@950*Z-`9}@}S8Sa}Ce@@}PZ&U7B#j+bMXl$2yal+gw=jKSLS-fO_AIsS zH{CEFyiFofk|R9Y`5gm-*14F=x3Y2J!p+ixmFhy)xDD{trO*om0QlWzmhPKouhD!= zWjd8l@7x&Wq5)Io4@3R$yhZgY`ABApsh~{Y23Xa#!W9PwL$mhHCX=$VNeyeS7xy5G z`PCTUyv`6UND~pxakMo>eVHzl(~P~xSu;HPLP6X#8{oldv=h1%!!c`6Kx`qZAU)dt zf@`?Hi;)K2t_26qYft)-E3mE`FJYUP6AQR;a!$5bTP{$XhZ7V)0tXNY^W4dcMQwd? z>#Unq702Y#P`3#DUDT$ zKa?3ywVAjq-92)}=W6+P?ZNvTvg&`=RyASfaV1PicP5x5qXM*pj?(+tNZ(n z;9PF#+&*XJaq=Vrx_np9oPd^c|zjSKZtA^6izrBd{?b$Mr zjhbLpRg_KNVWzeu_6Nie=XSy+NtgC0{{TeS|w#m#P=EHY4$|& zR!ZDku0{YBCRY0F=VU^6^q%t>1h(cuH2VH!{bNabO|mCBZD4k&D;p4T$aJoLitl+& zLaP{cN2up7{YQp~SFd|Aqhy4d!^s3k&ho=W^RTqM1vVKSf}c;!YpLvX!Pi>+*EcxM zcxR(zglnmOQMQcZsMw*V`2`==IbzSX4!-xObnwSA?yR;v!SgA<2>0*Iqq3j8zLsTB zVj?EClHq&0v+AXBXVoh~bBOVl6s3lBSjUO^*waXAavmx>BSq;bg=>906htYE4PI?g zqTxB$7QoY9+1*1#Ws=c`63LCtT(7*RLBLk#iG>_3rv3HJ;`AX`A$hG6|DraHj|<5o6tKFOj{O;@CuD868{Ip~w{3=u2ioqOuaI zR$6%eH+vHQh{i>(gh_+v&F!n6@JCzE{^($E7 z02(Mz16M;5K=={^FFD|U<}TFrJ3>(O07l05udxz2^nXGxx_4FgPcr+Evjs01(~z}& zxn?Nl!LF8@YQlqg`W{B)AjpsjUl$ejW}eNO*Ikm6yY%ElJV$C{EVB*y3q&0q6QWu{7I;m_tAP9e zl-aOYsBXSv$O*EIq9%>?M_F;?1BzG*k&Vic};i?%a6ja`}rr#YDr1I2q>; z-Mapyv$|q#xl5$%e*gUd{F)&r*)KmZ{DSqYGtG53Ezb}BZ-~6G?#a)FZkF@vZ9JBw zcU$uA`@MJX2K4XklbH?CnKa@%^Oe}is8yfCpt9`x{2R)sXd3x_@A?d zt~0K*k-Ure#X{N_<|oB*AzW&D$R~%jF6$y<^zrsxL|}P@Cm5C&!(c)m^>8Rl8bWMw znY%MQz&VbU>O}VBhR4jEA;pwwDtG_5)XD8WE$AN^xYC{u3TIMNboLrANkwTV%+Z(& zRl`MoA&0d6QBgVJ*b%Wu1Y)+wR*65YT~nanW)NW~cT^xJkv8D=3!G#ZQP{!dNIp>` zJ!pOscdhtp-Ur+HP^|n|5B->whGkx*UiA%UKbc0<(=W2=#R6O>df6FpLHi8#iHLiv zL)X4>cybIX{llm2OHj$XcCj<-zQ0m61GE8cMlw6S0XO6MxQ+XNh6yelG0OgaO2l^A zUh5wnF|!MQ)HU|R-@>faHmz0sp84Qw1(j^l8zd92qLR3%Jc1X)bQenjhN5Zjdv@wx z|CZnKj@_=J_FFF}HmgJ&m#>mf^oXX;o>x=&d47<3nXx(tg0&RUkU+!9E&fRvVk1}XC9=*$^5OS|MXy2_2svLRyKqH5#rm zk1yy`eSbd4M64Ehc1Wh-P{c1DR1><^{Pkk!*tNX|9%2cc=4_Xa1s=$aUB}bq4iKin ztS(}HBpZ_;J*20hyl1bH4n07n6`{Z&*z>v`=Zq98(7I-wP8?=e(h^E#xAXKp3F1Ax zBjPaCvtA|b-4~tnq&?6<82~i@sCv>dAfeWW=zcj@-AVXiSp(n03)0r0qtffbyG>sg z4OfGP(`swCN)+Fu>Ih+savz*M0FG;Bl#fE;8LU@2!T@|)0(O4}=;*eSAzCLRS2fxd zd!ybKL``eUP#o`X+K+=_@5VX)5Qy`2>*M&Secp8ytuYz6Fbje$?)`f0gR{9(?fW5q zR}pP@|KZF?^PEzbgDVEQ~_0FDq&pPs`KViDlH}UV+G$AS$JDsQk^>6n-6=6std6@$Lg;O-4 zinH;{G#oQfNY|n)-4`*lkB^~CobdfhBog@G3Eo5-)p106smh3R8fb1XbeCZ%ibb7! zV~a^%_zV2In`3^K;6jejNIYX~ML*DIhmc-}JZRvKvX4!APuI;wL>nqHHRdRWg9dOR zlsY!SDE@zgCE8$x+^j4CJeYKA^CXNVe4uhiz_0LzR(5Zqn7*VG zmCCsGIe#;{(bNr`tkr@PSPkkMqgxZWWmiy z4{fX|x*A+&Zh$?m3yJ%ANHzWYE?<`G+_&u52}cg>8U12*e5|fv;^n|$cl?PK27OZ? zcHv>B?)su$H0~=b_d;KH<*N9?rk>Wu^3VH)&+dpUMxO@?J2!4^cshn;Grt8%h<roe=*grw;3K1%|@EwNXS_joILrmzTVV<0R>i zEo(Or+x|0)7Q>T$80|3TLSat3rCVe_7Rj<(5#xnRM;sI~pVCV=j-ssFX3kBYlAKT# z%h0E2Dv<@n<4xZBeD7v3oV#8nd!jz49jttv^Y^MPkGPXmi!2uft*j)KHM5 z)oL81K;k{a!^o)II)t7dg;F~!5RbGt+m+%Z?{kQAH@j{kb&2_mzOen)>TmcbxHe2SkP?Hv++N@+Mt?fm4*aGpUW( zf0m)a?}7|<4n$bvfv4((sbWifM4~R3G2Z^8?jiKgbW)KItyMn)*yqc}NF~;sD0&Dt zz+q79kYbe?5+wH862MIcVs|B^QbkUoAk1Or+EK}EFx*}XfA(PC8uxnijYRf_M*=Z- z&=R;iec<)EQS1gTT=oD5{lfe{uXEOD!LH&Fx@W%tFo~SW9=mUWDGGw&O!TGaD<>p z_Tmh8ta4IZ_>aC~dQegx_q3S=aohaj5*JhE+iN36Cfndm1N_U~J9`hWkVOK1Y>%pe zwPQ^j-{sX*y_M&_G6=g%;x@Q9eQ7*7au)3h*bxS{TZBKnd9CdT{M~J*@U!vF&g#w# z__x+;QP&VQ#gsOP*s7ubbzIe0Qgd?Wa}LFeWYkEVY_j96&?S`zVI>S&#BOWgZpE(e zRwSW*8#xbyZV}ZMu?J-yoLD4Cq6NGtet!M?H5;#5wJKZ&*bzYXXL`A`h z{79!8@4o)%3y$%G9=n=l0R;5nzPkx!Zes1{6lLzjXHoBsTBspKwqcY=@uPhdc0c2R zD!Wa&E9{eCd5M>~5~%?p^g6t9se=%4tvc>d4O6Gdh>2{AT!6tNTy=IfW`BW8^N)~# zahyrMxEiiO_qC3DgiB(_7i4hq3GvaiY;z8F+pK6;sM3i5JAaFvWxGW(WSw~1zwvq7 zX(Ya~vcELwPicDz#P)$v`$(YHRbAtDL-Xj=opQ%Fqfk;mMY>eWT6cb_NtWRwzXZGT zKi0_a>$nqW0yPlg4GlH`oSo=!qIpi2SI(zJNf=eU;rrf{5_{%XPru0Z=q5g_6kMXP zY~VC|JK+pdtPyVnfDP z1|uxy5c}7&ucIdiW>AO-nx*NW6FCQQw&_iF#?cM~m&xZR<-ZVi?ou^U`>`vDN1qL^ zcIy;}`j7KZh#Z@jE_L#z z+9eI>QH~tp2;Xjxuitytlk7SR#go6 ziRs#;HGkv*`e<0R$#znfMm%jlcRlK@9UBnB)V5HDfp_&r_Pa~1tN~pV_ecI|-Whp_ zHzy1n>Yov*X*0d$6nlMsA@#+dHcHH|qC#K(j_toV&>tlo=*cClIb}CG7^YSq{CM*B zki1cW%eSyzlp-qla8#?IIik>&XCF7YQ|6rcEdl>|jE4wIdPsItbaAoO z`;dSZG+bT!jV!^AZW*=3^^e|lG?xS;gm40-A~cC4FT}JnuabkB<~bDyubBn9=u`U8)UhXSF@JQf z2>&dXEoJg&f8t6M*Dl5c16%wD?d;DzOZ(L%w;1mKZpz$>c-j|^lgvpI34f-P_{$ei zsU|Y2|3a33B*9{%lJ%{+Fd%v$AdQ(=<_qhglqZY40GE? z(9AO(NExO4o9QmLr2#}0vQ9$dIcV!tFkGm6jqeiAr&>EBL)`6$z1D_C!d-5N$5&~x zu_c4JU6oBpq#S;TIf;zrg9}f-##&o>o;?k5 zJaP+#aVFP@#+0lA^-Ci_J%H`Ocr6cpEHK49T@K1bh1iflGxt-juz8shMJ&IQWScH& z!nv*#@ILgnF=QW;W;U_Vmt^qwm`2I^5b4$HPg10(DO?BV56mAV7Pofj!$KhV(U)Cx z?A?8Ge<6Z?=BKw-p39v!c1t$uJ;otOw}s9I=6Ct~Jv{(S*@PDPR`!0gBTqyzu}HOp zGIRSQQ2zH<4Z@QpYv`+g2zL;yB<@@oKu8}P{kPWZ2|1H8sdH*Fs6`P8LARlosvhP; zT;!~1N9+|_BGQ~yqqvVc&UQutt~um`!B24hes3cQk1iAHNL;fc1#AXRgk@)s^7ucH z_S6@h`W_Q~`*|^Aa}`!&QRFOp4De5UJ99m@iu%^g509x+qVu!mCmH$Wq;;-_4gNSp zBIOpR4{1#nUh;NmtpB}x zVL43(m)kKfKA!!Oye|q(d26tU%3L|Lew(FAgXt{62PR-r{7lmKm<4fz5yqF!)+Bp>w`WDZ)gJ(QksUXI`#M!wZ2loc?zbc61V2=PC20twuO_6H-G)=86 zCRp?S5R&-4t}?Vj!%y+7sxI(;5dGfQd-;_y_ZKe#*{&LNS_Z@3Egn4E#T&lAu|EPO z>B2<+H5!sDSU;T&zpyr5sA6w0d%w}u^T58axg#Gp@WG&c6UGwk!Xh)(VxsSOis>7< z1@pYQ)WQ2CUJrSJp!kcG3G#W^l?@#@niGvUOSBljEzc;*Qp=dcFe*OVQb3mF!GK+2 z8{@f?&9TUo&cu6W5k9bSX+*Y*ch(Z(^cRhqF`|wux2f0sn$BUb$1BL z{Nw?y0Q&5W#ia4-T#+YloWmb|2~tdbITQngk62qO>!}W^{V^$R^}h< z+_~1W&a;Y$q5jKuqxJ5BT!+2?QdfQPYB@gF(l``Uk$PI=-0TG>d`sr)Z1l2Z(nF-l z6^Egx5&`DKS#Cx>C?QPy=YP$;ZMj+Qkc7e;s>%&q^}st&xPjDB6$3t%+vDUh%{Eiu zI{%?Qk+bL33T80>-yG?VfF!T=C8_B%m7*cCzf!Of>stNRM1SaP1Szfki&d#OHpw-h z^*JZj3ytB$TQCSI5@LS=Po!kM!X_;P?dK>#?c3N3MXwFPXD`mR7T;`pK<>fh8P{=9 zHpxHIcHIi4FCJr}9^`Li|6!u-&W->9ASw%6E^ilaNC3oI`2 zgAFJwYc(O}Q_H5nfGM@I4}Zp{^6FRS1?AOi?8rAgXusb6@Fnco&%{b^d=uVZAE!$> zaCB#`tWSWR>11_Lhb>&bi@~{xQ>RUNf?5YvJGL(I!K^-;v&5S3L3RCx&Ca{K=sJiu zR8~;#hcQ)UpFh1QROxEci=pE4Z2rj!k6XSN?<)O7MFZ_9elHgsggJGwh+?~8SiS24 zb%4}{qj|dQTaV`}s_&HRyr(=1ef+CXm__c1$@quPrgY*0)s}s*axq0Juaq;@DEQU0 z{&kZXiQxHAS($qlY2JpX^8)?ma-E82sCvFizNSxip2HPSqh$?M=54msr90phy5Dd-6HSt=$ui~k=<{hyn5jFw5} ziN~N_Tj*^44bNj!El4v(c## z53dw$`2vA8Ej2dW^u0-{3pS4M?Sl$gAw$5i&_%}+AjXSUjf?uocg3?%YiHrm08!5I zkF7aiV4eWe`BBpMWt<4*{wr3C5G68sqzSgnA6RVwtV(7fcK2xEkrw8OaVl<&^#!o< z1al@*;t*S9uwX6LS$my59R-%{JCtjnCMoIbXJ3E$UjrLC7GuSSq z2|J7>;~!cjd75)I_?h|R0HyWIqFNA{&#b%cA=M|}_Q%G+`96#JP&fDNW*7C%yXrh_ zkNyUQJUR$~0n@)Out*D_qVeR4zXrFt|m*hpMayxc?Iq%}zS zQq3#Vkhn6B^X`j9Xi9Qxemgb)Nm=&cl-UgAHXJC9Uqia0Fqcpy{+rX$N-*Bgpl+x* zflu2hCOpDSuyCDjDSt(;acTSXHqL%CIsb+;BxOo@!h125T%Y9K_tlohQP>Anp#Yh9 z_BvgjVaMn8qTiYvb{WJo(7<9?Cfn=Oug4SpU=$o^4`D7S`@AQeWA-Qc{{7hptcF@g zP-JNKl?TBE`q@84ezesh3$+{&l27PX+#~xZH@BTbq5OzM3$)YkL-G{eRXHGxl-pr2 z^{4nk)Tvv`9egKnDbos!$O=FJ(U{2_1kzQi^~*RqHR(1ydPz>t^&-x*3J@iF_w?C> zHKP0c@p&42h*1x60Sk&8R4&VcUQHPK)AK!(C*+o3YU;Atv$5wZ@(0O>H~?BP#62C# z4NBE?sByzUa(R3WdvHRxwD`tNyg*4W#7VXH=w+&e%G)k_@&A#qym4m~%jS4( z_2&pr^^h~=y*c3R>X$tUw4K#^mDRIf;jF*DXotqovsL%=ZV}^pf>JsqOrWzPKhBn` zJNM*^N>}lRqJ5j3cio&5>E;AYf)@q*QD@9SvID|v!E}!FMsw#Qgb+rA6b%2arwlWa zbg;Fp;?u{^+C&rPZ^nB;ls~V&1-TV(E0u?ly zfg2c68Sr6252Sg~j5Ys15#2je^BPLr{dj7FJ4kZ99usEsz5wIOY8AXECe3th`1qly zt@QLW9yQ;PL)MWxgE5JGU2-B;{+gDX&~=XM=xEo2sDX>Q7{)*R-jVv$DioolUhhYU zqxatWt122~H9vGAYO?BT!_#J)gB6}f7%Trl2;34Y`bp(?BGH_C#0Eujxt#kFMhK3x z^}b_uHfMp_&zb#3^AH~P{EsF{ zgqvALw1C+J#UOJc#7~185R3$}ZGs)t$qz9pqZ!ei^Z%H5o7HsCiXQ_TtA@Urzyy$w z_<2AeI05ZVh3A>1PF+3Ku?N@>py1(t*#1DM2!ST>eVYD0-nT-dPfq}Zz<&za!(PU+ zl6~~{9jWZMA;E<@?Vapr<`KEtn4MVW-JU^W7PKFanKW)|iP&&Y?lA_DK6Hi*`Fb8_ z_Kv9d}0iKAOzym_B8F$%$D|gVbjok&=y0F^RU#gYatS-&|Z9Fv+f+n`?ga>0`3LTYvu#R?=+X;a?AkE|o8X4u%(#P09&Cstyab(Q;{Ul#3& zA3|W$oxnodHs#JFH0yZU^7tJv8%{7*(kMBZeAhJg)rfo+L52jAj`#g0CF-2nK6qhA zW&=R3rSlv2>=SkBGw##ZhK7E*#j-qZr}MhSh^lR5o=B$w#8Y)|_0!u4`ai0yA1i&L zfjQDq1IOtiXce@S?FFda{yE>lL;nXbpi3!0$vFHQWmD~E&^2&gz^w_f0bv_kCm+sk zFffXgI&E?CWyme05(}Uo62)#XNoI!+mM|oZ6=&*MFpnvAqi@ed=N7Jff8IOYuLNn_ z0Kuir3-?&?p!5Yp4Eo7Z1ieDO_~xyd zrCZ#xKf0ZPiKVwx{p7(v&>SZow1!c*4{1T1082Y}-*SK}4X$T>gBOp1JjSCsqi_jl z!W|;>A{cu%OS_~T5w$q7X$RTiMj4($HG#N&hXoyvck8z8u@32bb2--sj@pN{R29z+ z3vbNfc5cf3ySLPaEBJl8v2ix6#PGr=8_SV#6394GH+Gzb^kl?$2nEsP2AyIyoJD|M zJbHfp>LU=~C#kgsOc__0;ZHzm7sCt$9INA3ZhF-_rdOv|4GO^tx&8^&f64SERvB?iT9S?nih{;K3JWdMN#hg#Ja!BM+0ZSnlAk(Bhk%I zIr${(uQi@kevG2s#h@*~og!)AD#S>{bJQyvyWvCMz<0wVv%A||D`AmWk`WBk{+HJD zJ-WPTx8}Xj_jFnMN*Tby`SP^qfC?gWL>Yof1ojiG6n=cZ=Y2!D_IK9UJ4s%w#?GTA z;g!f(>>)@wlvtHYg)4&3-l`NF!5C#o-uA6ZRe)Cb7C`yR=fiKbtYN+{0uM`S)(3ES zcBh2L?l>IaPOV@uc}rE6O*?-(sydM`JE(A|?o_O85iB_f z2L2++$LW6Vi^H%}cc09n5!2mK?Yap0E14NRl>XP&PRm5sd#*1<>#_%t0A7?AyUf{D zm@#POfAR4)$>L9y4Vrvi!N<{J$rECNU1Z4|L$h}*`;v7leoRo`J1O6_)g$!v&Ob9` zJ?q*9t;C!Ly1{s5%{@Rf{O!50&$R^@z2YyIqnnqDzdn2E>pO)Z*?tbUe~EF!C6l;6 zG@OLORP<9TcgiLGLHEyF|GdYRf#enC!AK%2^{*;eD8;*no$!SkmF9CV(-o{f)>wT~ z_V{)@sxGE%^Mpssi8u{sc$=S1Hyzu<7&6gpVU?(N4K4>6wZ;~Ft`d`wKq^x#okJLZ zAL9_02q+W}3xz7D8*4$8Kxq?hys~4eYs}wBlT$${Zy92U2QH z&8uI3U3{2op-(1+D(|A{xgCtmTQ0k?V!oW`D@Bc=E?T#}+VUgLrGefxd9L=z&VX8;@PlsKIu?y#nKk$0dXZsS$@M;1tbT6G`uW`S|k2hL{ zFwSW=WvaUa-3MoKfTw|SXI`Sh+7=tRRdO13Z>`a@m8rnp|ET6%Ak%~nEm`k?g&{k z-L#VXIoreO#{A=kfx5CZ^WXw=Uvuw8#JOsJLMQ9F+|{mK1z}Ibv;}IQC!SLaX=nj3 z#!CfyCNKV1oNv>f@|xL=3AinZos!lpk5cL!FNth4IooCDnNd6xs9dE?>yZ1K-nsCl zHJC(Zc)v9?QSK|48j^6qBG>s93@l1~Rt`%kZ3NQ%Mlq5+rTE4IR!3%xIAG9kkiwVYQExTerD5bH z`#P}v)rTdoCfMFo-i~)DQVT{l(=2jTl#JFQ>ty7Y;Lpr)H=45IAvELNMk85L3&K0=iVG!8maZPHtHxJ%4&d{RMPv8?Pqlf8cwd`r(b41HERmA-j(MJc zO@Kf7c7nmN@6dYE-FWKZXrF{K0AaKYI?D}hZ(H~Wv#||!Y()Y+Px&G?c7Cv)p<}X~ zWUqXnFYhtBW%)5yC^0R@*|V~7s5xdK;2?B!Qpbm!9oI;Ph&Z@}9t^nrA~OLgLHJa) zlJRtkkbnuj6xqKCLSEx)TB?{jD;N2f#BIVoejkE5;~PSXps;YyA!Fij@F1)~p3UV( z*j|o=N~mrSC~D&}>sZ@b9v@(ZQu-hW)4O)>!Ln+-#Qu z%IiHvLw^VfP*GhoR@{)W{$8wk(g@FO+Fr)=NqrpUWtz3;(YVYMVgFp|D|b*6yv>(Q zHeD3Iip4M-366PV6OZU0{(#3;(T}&lc|`bB7H=BG={u(g6-8aWB2{%+4)sWJkV}oU zy`vr>9ZsXQ{6;MO2VJheVeW!{>=!iK+p+1ciipoVa%5dep-bxF@}2KqRp4)4lVA*5 zzxj9Y)Xa=bkn#-ooo^O1^yy9CF|)GQF*4D0dg#Y;h|i^nC%7eNfgej>`v*Oh|GW3F zW~0k_#KPlU!t8nXL$!^0Pz)um35Im&0WAc%2%ajM)@eVXM0p(VZ(Ln}Zt*a_NNt8y zv$J!0kaqVjo?>ZFRC!wDX~#t#h0Q9@Qh~d5fk%5m&=1|;VEa@Pn$!aLtLNBH#(J>h zOVx6O*sYQ5u_M@q$CbFZ8H_gy_$$O^&Gobk32YWikbe_C5`thAlwUB&e0+ZIn~a8d zU`Yyn{)gM&{y8DJIE@**vwA64rBtMugIUea7@80!rmzn>S76OSsV3eD)#mh#wL;F^ zFR`b&uS&WS0O3=~bL_S1%v@e@9WS^V)>$pf>46R$jn)`+YO5A6f4zfGw?W_iTVaha zzGe>uw?pMyWzL~q1Z2T)&H7jtZcf2|{yz5j(I1QIPdRGMVP5qDqZG+!OMh68zH^xMd4B3msd* zwElZ8nYlP%3MtmH8khe0u-maoZK{#v|S%<#VrE1Td=P z!F$}#!O71_|4id#XyCaynfsGez?Gs|3l_f6U>RG$>U4{TPng;oFHzeIYI(T^MQ7u! z9z@1@tsJt5$7L=tI7}E3v(Aejku{Q!Dl*8ub^kXdMk6p3XCTLRL-{ncgX6Cih>k81 zb9@CO`qXwpTVgJKVQh33TOh>CLrglZys~cC4rJfgSs~TEe;O>l>=bR1zb7f&7+n zqxJw=T?o~4VN(fUmTOd%P(>=FO%{&w{a|&DNHm5sCMVI2%1?8pJMez~{`=?7kFK_U z7QwKFVvE$K=ou}~E=rymm!Ft$(_>C?-QbSQBGdZQZHAwzrMPn|&nauf>Z6$@&1Xxy zi&~-oWFGrnW99r(S{nZTCFp{Flne^=972PG+TryaV!P7^5_85nXVk!PU-6F)4wSxqSSaif^3SxTi_Lrtf5z zq!g~PtjwhvKw?x(EJ4ADxb4eEsA!@;N2yxqKS{LC5CC)w`mhuFjq(cb^y z{nQ=eQR|~|vmiLWF+b32#p?3vCM&U9=W8h6_pcUKi*-Z$i;)gO3GvFtG)%X@_PU8l zaX`&V`-8!&2s!^p$EaKs;q0>&k$&(TSGGu%#Ox{IBBM;r>3w%?%eb&qE6*pK1a3G2 ztcr9zc}r}Uw@-4M!7NBu!^`YqxjF6Cm618L%CtxXYjMnAe)%vpqIBDi_;^lSNplcO zF9&N6PKZVDPt=6o`S=nD+I=2P;NM{8VHk&Y8%bC`4E!s;U|L`i5Gv9%eW>zMHB^Dd zK2n40eN^>C$Wt_Eb|9lU;nDp8gmJ*>WzYZ0tgNNVN(r-HyB4hqnZbl zpSLfyIME$2QnMAC;#=CBj^PJ+o=UT4%kA*1MsPm<3Q7RY{g@v={U$R(&lIBP0FJuO zkfL__EP#66=uY#9$+Z9jk1*eLt)(^Ib#C^DiHd%dD%nr$?6G3pvwNZ1!k#JlqZLY) z5Z$`SM;aQcI7+lW%@dm&lcXTZ-U8?xd=z>>&}v^kqj`qC!mlV6BQAZnrf~sot$LTg zeU=7ZJz$(*B&dZo>&6TStUbCsE*TzIF%pQ5X((Noj!)BqlKZtY(MJU08(Q~T@B!pB z!hDtDo3Wkt>)-}xpe+xnjkZF-vkM*$nMnf`1@<>{H_%_jgEdY0$Z3A=F2Aw57c?jR zVcnSzfZ(!AQ)u+jh6X>G8hcAMOUS!!oComMl)tKk5_d`y5Dld;F(;L`KLJ*WZAXu0 zCdSWOO=D1zY8p+^g0M+?*%b;8dZOQ`MgOgs;}~1Ph9;+_1s;#>8?iX(A)R3@d~i?Z z;9UrBUjCIfxzjzVn*7U-_Ad%^x9JN*)BiE7aCvb7_+^8PsG(?y-J{i?*?{oR>2B+K z>1O!r-Bit&*P3f&65$pjv*&=9tSry)C?1TeP+O+NxYUhh`Q$zCpJThMiAdN$o5l8= z1>^N{3dlWNpB>#n9zY2&B$pK4?==~BQj!$>Rva&n5zyR-oByqH@q6-*gavXo)QpG* zUZq~80PQYe>?!VTia00n6pun)j4G#Fp!pP}v}yhY7*!OoxvdMQsf%|Dc{ccSIA0O* zq}bZKkI?N2+`9q0Y7Z{DoVFQ4y8?eDGN)W67m;UL%iZV1+MIm_n*RJ~Sx)S@e~gig z;Ke0{m+jub(W@xNdyh=Chuy`QaR9MKb{Xi92%o9wo$Ya@s`v&5{S7*ZDy&X|0vJyI zI@jash&yqc;oGVAe?0sV)c9Tye|=;Bg+V_0uDE&{QH`-CODXvt1<*C=vrtLSOsP0+?LmY@$QoaV|+upU#h5 z4IRbc!D+|yG|wWwz21L3L5-Wm0N2nFS$Fp`G2EZfF8Bp`A8?S~{`R^dZ#fTuAYjlR z5;B{(f;A0R4p!bTPRXK9VM3R6WptuX1CW85rrduSPUxM=CS6GwaxdzduDMF4kn@TV zsLR>zJs$RlF>eVF-@mR7WsFFu_?R&C76+Jg>b*Z}`?W=~spMOUiowPoDEALk5iiI! z79$tqYyhQ~6chdRr;uFed9Mvi-w*-s)&D>=o+8eQq?LL0V%cs#iD{jl?dWbTcTS;) zh@&$9u(x`>{82s=_k|XPvewmSsRXk<_y!$tqxpb;Vt>;-p*0 zzB$thx`DMHWKxf{+2xEeKNr2JU$Akv7<+Wu8t*(H*mgUoyH!E-AEUm7wZ$+-Ny4pC|T5oR637d}+DQAH{7%HBc@4*s+wJ`L{a$ zKeEm`tf~Kv`y1U2B1j1cqJ*>}j1B>j5~UF(q(tc&9V#M%g0zxSLqQrwC?Fluu_4WX z(PMd@eSgn?&vX9Twa>O|*V%RLe9nE}@7L=z^8K7o;C@_hJ8ELGh+18&vpd)|UHi)O z^M%6l5=Ee7+ zbP-kSrHf``7oO4@SB3K3u-jE^&Tr^k&Ek;3`h)ym0xw+G`J|q%QO!(*T2VHbV|3YJA_aQrC z?D(P*DK^fI+_YYe%5o)A$A}*R6KT^N>@*4sUHRbMk~T$}rD`u&VKmp^PDrvko)5>& zcdFFaRh_{`H{wj`ZE-6UHMB~}zxv38DB{O@A#)COI1VI1IX^Jt zy#v_JYhrmH<~Mwgb6@#DDU1III|ggGpQF6apCM**^A=5J1SS))>guhBwZZ+7$qk=R zP~;w^-B~XF(vm!XCxBjoOOyj=BCP6b8gv?h7sklARdjsk|MT!32y~}QD+gqzwCY?2 z;=)|InF>MfHME7ssS=SoPw$I-^o9U&?H;8uriQYk#|IbEoHQDA{=8gJ004RlE>Y)! zu6rLfT3nMQ{VI8?c<)%Su}6g7HX-#e7p_4;nzjmNDY3qJ173L+iF3fj?uibO=tBHt zv!chdZt|-Y&>99S5a?xl)S=33Uv33|ju9_bJAVn&EB#HtG&*$3NsamahJ7dlcm z_(^7CEM1u-oy4rsr=AgoUIZC`y7DeMT+?_6=&6&q%8+J~S2qh)|0)GdeD^PcL?FE> zmDn?ymTPpKB}1trjUeH0C5k?Z z^TKl6^A4PeotuO9RtW%x?cz9vnVG@ zI?1-S68o0)tp83EgzPRgL)LfGi3Kk{5@y2H{kj7Bvcf z$Rr5oVPSV*C=~FdjXqad1tqq`6_2|vR`Y_s&r;IK(M6DdLqAAi)+!`UJF4mIXq#JQ zcj0K-FDAhb>GD6B>h;4Xp)bQ56d9xA8R47TKFF*ZIh}mnH z5P4kcRGK}Bp~qB<&tN^9ewuLSd@w)|=*rwK4{Pe4x#uSw`F>CFh|$QvX{&VxF?CiW z9qFm{`7#s5drqNxz?)8G30M(UvVlJ3YG_^{q8 zkKM;V6kIWh#UQ5P{PLYRoKog!>$CeYY9Wt@C%ygH zI))#_7Tv-i|45hTrv0h^OT+DuXod~*f9FJG)N&fP;0vJu+&zL!w3H6e|A_$0*bxQO z71?vWmh=X{YiifSN6-jXn&e=H1()^fz1Bf`u(ILhbCz|?JBKSf3I@R}yB7!t2{(7< zzCs5Inu5gk3*nTvH&2(&gkaaGN6V6MyU)BXSogtdJs_>8H^IS@_S-)*)Nfk9%XfG{ zv3XVm*7IxjZk}GDH2DxScyLy`2@i&L-=1l$Ji74VrFV#Ch>8{qJB69I?I@4fHKzJP zKJvd4kRv@Ns=6+hJ1yD$>bjiT2CD^0WYk86kDOa5J-~k~2W-zDEuR=KoHPE_S`dx~ zv2iAXoI!#Vaswq%v@$p2;q=*d-_({|n*KzB#9v&0%ZtIrlnk|qA9j6!#xa&vwawU; zeI_=o&~RA|n%yTihad3A7SnV~i<(e}txHlzWmRRZ5u=L6k#Pa?FU3c|-lPcrU=ZtQ zTcoJp2ICKrmOomZ<0mHVpFiT?5scFxj0)$x-f1F)F%ePzk>c7^Z|P{(ARe%@KA_Gq z_=lKv{~id${nu-p8z>UPNJIwXoE4?|4Uc!(JU@7O*Dg`GHbgJ=mhwN~fZPwE44Mx> zhQe^`a|^CN_gP?eU3b8jQ7w3L#1<_O_IB(B@2L(f{TLd+!`kE_HtuzMk3a999{L{*~3`;nrkB_xmfFC?q_bs00?lgm+;epwBOf zqdOPUK5(^GFP;5V{4$i=qa$+bLwQ^hI3!N^+Qyk`_(!GE5M$CLQ+~P4Fqk)?h5Pq? zL55s~-H*sjXjq5SzNKLr&HD<63Wh5MA9TTYAH^`O8A%?RZhCR9i+L#Rc-$V8WKRDr z43U?dTh0)=4eNm3FQOZ5`5ngHtMBi9Wp;M4i~8U9J3If{-%2Z>kwV8;>ie4VK3+Z7ldJ7DFv~ z8=GJSuqCLFE#koorIa=kD_5>wR1ej zGT8D~U5R20V9_&Yb!sHKe1ZpbmfCywl^B2K{QKL*{ZbdF z9x)VpYkjYO9JEE_f~<8&9mD0#V+6(YiYKs#cX}wCg1@fd)_r#xsO%hN8Jf0-t1P@a zXYH4!#E;Kh?1QMru2P8hZh?{x9--=K#mu!q+u!(0hym;!W?``h?qS>|E)rVia`3Aw zx$RCU?;0OJ9BejB-M28@Jg<~k;auF+toT9bA=X(aMdQ>!w(G}{rQz$n#f?KotB*rD zp9&mZ>*zWT_~iUs^micEAL)iCn|q4oEyv&KX7VhUOswh&Dkn!rl8 zO_n#$KFPzJrDwQ}+yLS}#iLIT+Y5o;zZ7k(lTL z2YIH0duJxLzNZg7za3?*I!6CY)on2HqtwIoZZ(k~b~u%|!wF_WMo({|Az`ilSiV%ZK#gzu1dZ znP0Mhvah;rcj`R#Ztqn}B@>1WL!4p&nZ$#P6W_m?sVoJjo`%%*j!T0p4b&wCHvTuE7rKMkO68T1Vnj{MvvWg^Y zoDg=-N)ym`4)7@63?u62QIu;bl0tqR)fsD^S#V59Jbe#%YW&NWCH}Jm{!c3K{rzJR zZ#d!t-?ZlPj`hz05-H3^M`MsCJoWeqTN_Npg5&7Nm6OdqaO#)``?c?J{NTfk-pc31 zvDaDtlU*oLQNfEwpcbmh^q$mwop1&dU;6A166z zUULR8PjLtVMrlXj}I~L zL$LX}qokw2l;9-d_yGNTHVN*`KcymD>=9!v3vu1asHY!pXoSDN(HY4KXdVs0Rss5Iua~bXh!1soSEdUAO7RjOXmOyQ=r-v> zy)0K#_5G4x1(bMO$gpOA$HhTr(ryeO_zX&BJ_m5+`zj8ztGA*-2l>8elJIO9ldDD> z(XDKIS(5{ug*)hojIuubWvAJarez);)w1=02#h{n6=O%JLSL7@Du+lh2iL&vA7tSs z53N4KVJHpXv{a}>r=Yt_j7xg2p1l@{4jy;{{t9huq&`Jm{pfx4WE<@?nF2K+Bt3FTnHVTU}$}tqyOu=l*z>wmJ-~Md4nNKTxY>45^?>8 ze^U|OGYZ-j7wsVYQ)u;4ne=tY>$G*8q2+|mRhqsyPkMj&gMu>3U3uzXFFZaJ{SkRc zUzrmCnfM1fx>^!?RJ*hynfu`RY7l-qWyK^lu&=0{&Ui~;Uu)&)_FD5{t2NSJIwkBb zagV3uesC~}`n3|??BedP$q*x{MGpyvPnIXn&_~?CG=x^mK&W1do%Tg58E_+=*&Cnf z$;2mEK_RjDu@zG9-<|R{>6JQQ97jpUy-)vS{}$xI$?U}!CI%1S-slA$&@EIh0-rMi z^IM8~qxCBVaHHJ(gsewFiF)XFc#sgMtFdy{@`dO-LC@;8uCpfz+33-~1N}4A_`pUE ztvRd4tFKQ(?rz*WMd+QMnfNbLKkRKKiB2_mPKrvyylV<+&rVXWHYIDT2cA)cI>5^o z(a#y58+ycEN8sl+9?DHasy}w-fae#A;e#zt8+R`kg-WW8K(&OHt)CH-Cwz*SIh&~- zUsF(|vWc{X2R7AT)cw&u9s|6Exy9pO$&8C`FA5l@Y%q;(Pr;*V25u$9w(_xz$c*CP z#}ZH{auQ10&#&+}^6@)HPx{op1u(;EsWwK~@elja@t3`j2Y+h_JyC_7jEHrf#@(jA zaY)nG8T{Rpc>6X|0CV4;Sf%jgV!s;DnU|0W-#Twz4Bz-|ZR2|ZSp;7w zYOkGXy0LqWev=*O&YH=%1fO}Di@Jc9(0`ZbYgQXgf};F#YTB7s-YIBRerqMhY$W$Z z&*o8iyMVwtA35{_HD|7iKYpTFzF|}hKf8S0e3bIr?+F2*=C543av)=6#%aqvdi~xP z%Jb}h;Ockd8{&XNu5=9>r@copi!UX;FDL{IKlis+)5KEtkBCS=OzGK=TnGvQxmgaZ z->ndM92F$=`a(WjxY>#gT>qI}9)K2nJ;T@lLdODP)@ui$VH2nD2!vABBnf&PJ|`_~ zC7i&jV{#*)i&w`10R!nz{ozw40HmX;|wn9mo>^Ssvq@LyhV-1M2RaGQgn8Hw4f3NkWr% zNw5|n#O{(x_&Gp?!f!wI1?mA&cs)^dLf>y$r7!JImuTp7iEEB}tC7#u{@B%&`o^H9g={_<44xqHV6i=T0raftUYWGT=0F0fNbyt#Gsd1Wtx#X!d zKwCfAfSF$Kmu6=m2ZX*Smcw+=!4!qHVYWJ`cKYDKV>z88Z(+=7JvPU$phLP0WAeCW zBJ3*3&NqB9A=R%{5)x785(Uh}Y%TGOIg+S^+n#NOhloh}`!0;#gu08MSAOn)h014`~INY2RBVr00G%WBEo4ll${tgMidSK99ilg+3@rBTdVCqU35{#2N+tr!yoAUTBku3LFiuAg~Z z^7V9^yV4|YKj4@`k$8mPI@`(e{fxik82$Jq!wcgJ&U#NCb*E)Webn;ymjh}ks(N{W zv6xfGilLLvK1#yM>qe&Y&U`=Q?xLMv{-AoX(D@j~2;!|6tI|_0Aa6lxbb^~jMnca3 zd}O#9c60XifQDJ)VDJ?#MbFoFFMrf@fiTj>Pk{x4Xj(RDsdM4fElgL`6~Zly-((B8 zcLvmkJ7H^=M@pTEZ>c@yd;cXz=4$cNi{OtE_*Cbm}bxx6$MplvYzz`olhpZ@!WZX81n>*4voMN>tZp6{%BWtD1&GYM}mYp|&M zoYECZyGNL^>b>}L*5%mzF#A#c@AvUu?^$ZOOXMBSB9|6y)!j9(TXf#Z6Qo2vYl+gR zH;kLAG{u}2Txdp%YvSL;Q3gL@Vm+XEcKOcqq3S01pntW`j{v%%Y|7lqeR=(Kl3Fts zOv{qgpXsMVwg2`e%#EgJHz^?J1!4AfW}PvqZDaIaJ-6Bep0FIEJMaewN73cSf3(3Pg8&A8(33!gj!M8u0Wl+*ttsw2%8F)wXL(vD%<_=T?Oc?gI~w{np6~09K4@= zM#>%&A8|rW0DF)^`RD|UczW+iYolJrPjGDM{q{qb!m*LVr}>Zd6 z+#}haGn0WC_HTbxpE+}i^93WuBLnjukX^Bh|1XnRl??3Ner6 zKRDc7CIx;PJNtNe0&yerWUwRwC-?*kN|r#-^NI-f=%(IE3mVf%vq{{eS7y*P74gGy zrt_Jo^@s4ZDE$jtFA>aG4+lI1e_c8L%X4U*X8{8;udLX)tW>k=blzL0CAv0OZrp}Se}f7p2HXSp~ab=6^Cm`A)7N%-0FT2K3a5n z=l*qBSZ=Ldz^#f(bRco5t14B`fA7Hlq^!p!YhfAjUZ!spdCP*ys5$k!jYzA;bxs&* zGQ0U*AJX7xJ7Wv%u~r-!(Mwim#~u(d4pP(OKDtm@J&n|0tt6bfoU8+aRovZG~isUM2g{hUzn6_23YSog7<-IjjdA_|pK{Ptw%?%7g(_1$89rf*a@ghrxPvN_*v*XW5E;s*d ztxcuf0%SFy8=*ILf+8H$r}0RVnGHRcRIE7|Tw)mzJCn{#c1+6(seV=u>?Wbxq05pF zn33{TsFsPQma3{7rQFXPg@Vo}qiAvGM3~9E#o=&Xe3MGP1q8!+G}7yY5rhinK$aMq zBU|k=Fc+LzgQtFmSIRD7%72M?fBNe|o12H=w3ZX_)RH9x~SgP!tED-Q)30 zx`VEvFWlxv>;(4KV>Wh&D&a{_ziea_>8;weyf`90SSmPi95RfhkdzWrs5d&npUg(t z-N1*>PhoX2Ss2)h1Nd%PZ|hD{#`(H0N34IDaEG=Sw+GTpc_(RJN7JC@4{X#Yt>`Xh zkGW@Xa!U47_U}>FaQM$<)+m|*YbcIClWMNO`BE!~H_JuO04ThAnu8uV|`$m9RW3z2OFQlsRvpyKBy_eo}bD>pr zeVj!cMo%Mm7g_3VrkVJ?Se|G(OYCGxjbDho9#x(-Q-7H}FXWzFs^C>P%qj7fVv0e{4?HR8Dx?tiPzTQ|)^7iqrcU;z!TTv#y?G+BN(2;pS(lZlB!QXR2hz2Mlf9T=0&| zGhbY&6Ljle-Mw&8!Z?_XLn0b8VCH=4=vUY_LR?|c5CA+@odDH@r~9ga2wSsh*h!(2 zgFx^yWsVP~`tez(;r^}X56j}=U6&sN2ss(kO2N;>hYpua&wEv@!%1y(I(`~CuezG zwiUP%37p5_RiDjn6$1Ft&*469=e#c?pL_|ra$Pr7{0GQ6d^b+W)9|J8Wl+aW?*}OO zH{}>iEj+BD>fw_Y-8d^;tuJi+usq=<59tr#)qSZU{}=9B^^JDL-z^`k5l2SmHR~5x z@~`T-hSmo!2)wskbb=sxW*eRkwU=C)yy+7eF&+|Q`Hj-z*)l5;{Ka?~6Wo6SoW;Bg zMnjdGC@A+O-1$UhrWN#9T0;8Pn9u0uE|WQP7<%zK&Q)%|TpN~GTQnIxEK0rmD!3*C zwyCbZ=&=oH@DiZ!Fn;C@GQMr{5vb#MD7$_hOmRgvAL@ml+PNiVfIjLeXDEPvFnDYO zCG>0OFBg*_7c75;EFpVNg9V@5hjQlkd`^X?T%S;g8u`TBynB0zgXrwU1AaSQL7%g~ zHy(v<1BqW=d6&1lMQm`j7Q52a|B|3~A(!Z0psSKz9G`C8nXS!bDyCNx2ByE?aHo7p z(T}H#{bP>XVFIN^Om9Z{=-s{Na(Iszizl@;X3Eo`!8{ zEWsR?GVovgm5mGW3?}eC6ahy;@s>@Kvs7ExuTp+Vag^|%N|Ka7Y#sgM-RM0M z@(XzS$Z9QBagGZz8QjdemWG@27|(fs5*2WDnr9P#F!o4-(C;Br<M>Ht3n_XtLl_M+Tzi^DtH zKrvSRS~qxujc$R@+x>nWoBA>0B=DOK*q2y!zckXd^e?yg2t}fDQ~tsB_4f!c&UQDq z_*tWH7?%%u;Bgcpc|W?p`-X8;{Nmd@)^SN2e^*Vez?`3GQk~M?ygBCqKv^nHtSx{G zN3222fTox7)Se~}lW34ovgHg$)Hrw`M2^f-=uZl$)iq4Xb$om)fX35Lg!gQ)uwx!* zxW)JixX`+Lz*68y5iq%YVkX76?o7b4_iLvC6@pO|EiMH~V(0rT5c4ttF>8Ws>3>GM*X)pu|7O)(v_cup&(- zwrHOaUtF3D9Z3b`mBOKplu4bNP&Dgf{bW`N`)oc;4CVF5*dm!!eInXYfDzvNW|;xC zm=TG5dLodF-Z?fGOZFFXdR1zk9ac}gpxyLwB^2OE<2mvQb+i|ojbcY=G_%2HZ2pgo z)8oX?N1rck-~)L)y_d+othLsuG60Yp7+^o}K>~MjbMbrVwE3}|>XGffud;oRG7VWX zlgP^F+_K3tyE!(#3xVpiXV|N1lL**UDe&(2b$yGk(|OB>sm%3U7-Svut6fq%m7kaB z<=dEnCEhdEz`nU;w`+N)%hYxab4EMN8XU*&*Zo*{&GbUU*?MET*3Rsj*KJF!Q^o_q z)BA=Z0k0vC5P4;(Gl`V~->&b*%;!$1K=zC#6@3{Pf0NcT%-kZ&Plj6e#7v~@mPMr_ zdx@Q@7ks#%VP)`XIAS)7JkV;|Q25hFI?PcQfdCU*Ic^r~y0Gkv8ThsCS?DGEabJxf zv|dJIyL7@k!_K{ZzPv2ee5MslV;ZizRJexczchoqMs=fv?04-B>G9Wzlp=z=d( z1Ku?Y8rtO5wS{>RCt9j?(qUQmgJ-0-u@sb6myvLjxfCZ@l)R!5l!ROR*<_1Kx3 za8<0WHouFZEYJKd05wl=U)<6QCgZ!lq!>DKYh*6~0swIS$ERA0gz;-wutXoN9F)9~ z8fMLyd{AQV(zD857TuE^3@SX6XtV9uK1`q1;khUz`EFf1+1aDNUK~@JN}XAL(_xS| zw_Ulsl4(ETzRKon-T)JN<}II%uL9v=9}z0lLEl%vYoAj1p_pj=z;*A`FxrSEEz?UU zyO~Wn!2sIFr2BDfdF)@W&$o)U@E)}>XMvk5PDN_KEb6x9_vc9qS{1j1CoV@^jZ&ua zE2%qe-3vp37ob< z4yQus*^8GP_ZFZM)AtACZpG*Bo039HK_ui_VRZUn@twN&d_?PKpw? zVZL`eM$Sq9hQvxlw6w_vYLsiPynF4l5v5+2ol2Hh1iq!e8Z(>u3cy>PWUHriuW|PH zT0KrR56{C)8F9Qjp7n}DoGJjZ+SX7Kod)||?u|od*Q}nR56^^!yf|z;#}L4wwe5It zsc!zX0|r|AN>?HLps+KMRRH4cnUfXXF{Lc*MCkwI4qbF&t3-%HB6rQn^@FviUsp13 z)WGO<&_0jYr3B&Ci^sVbxB1RUu+7hr`Cwr3mGtta#`%xtDdT60z`!L$8U!`BP1;We z9QXSJ{D=hlAQ^wDK0_`%)&^f}v}qvls_F2{dX8+0-qKi*nUToQVMPu+gUbjg&A$FA z#|U3UO7`Ecu7K@VA)0^HkivnZ4jTWIy$&Jh1ruWC9UiN$TyyXHSDuAisbfqOl~%i? zoSM^(Ga_-1=AZ88;DG6>zpTEu^WPDEKqAKBvLQva1I!nd0E!5JQeW^oA5c0Cu7F&= z+kHZ+!81u%3csJh&?iWjYoT7@vA3?`nv@WvvXRl--7^g%Bs&W(vS)W#mL!lytN|YR z0FU$RpcOaHAPq~zoW-T1uu6@X(`D{dcG*6{w5*&>$oYe$6K@pThr9_#vB*q{f^h-! z>-G3A44t2`Fw9wK)Ie&#Uu)TA4;x(Z^1E1B{=*SfPkQy%ALd^&L7{FbRsOR@Raq;= z1Om;;%){IE+1bI~uBBxn%)(X#?Sf{aj%_6L#1D^sWlJ2R#Cr=rc}H#NYzXiZd2x&u#Ok1@igZ+r?DTjTP+VIn)I+a>N~SSA@3xUy@jq& zlh$LengHmySKh5@tMqxc647bb@5pql4E}j`Wea{<_-H))2z)QWOk^SBRy*s|##Rzv z`8FQ0Ybk4?Xt(wwAv}PuHs@g+L{QdR$4St7B1WgQOtW_Wg5XOSv(K@ViLY9MC2P@p^HFGM6mA8ePJ)7!)!sz zvRIcT9BIA;5ARA>%r=o_^Ul~j6w<~B$DG(f}?RW^IT`;rH9WJ=| z#p1>Se#?C1P-5&MQYHcm;aW3b^ViMyK>roe|0#-q*9n+7?R|A>Q)@`x9E^x(KM zmAITT&ErtVwXwy6Sv@pkpUE8h@Cnf` z-*}u&UPGt6tv@LBs*!q$Z|P?=BY$6e?tJHkmsfUv`V+=BqPg0wyT6on4cv90U<>3O zki3SqENSAS+7I2xwWHQ|jxl$0l|VQo}!=6ZF3P z)7xiaU7B_uE``x82+a@H;(%FCp$AN18O<3`{2_}Ux)%0R-2%qx5WyB9mSalQmrU0z zc1V|38=kL8m2+yjh(3jUbA7wpPB}I5p~Z}7*IJgpwXgPWpbp_#iBfcpJr3hIoTXo*OSk${(w zbtvijZE!3vk)ywy6xe8oAnk_U{XBCpZmlM=`L;DB%Ci!B^NADdZPHDj%&cbiwhkr3 zqv6^uuH4W0DHfNm7RRxXyhO&xmx?VCCc%s`0@6Pcqm9ZdtQwYCzT?JcLPGuBrin^J zX{p)Dya|?ADZeGieV2tt43Q)>&HFdwqJ?#A?;+s>uJ8l8%K}I-5_uf;hR^DdWHLyEMJL0JrGXRGqQf?h|#etfc-x=m6V$;IH&b_Q{#xbqUFujI-BsHveAD zcCJ^Zj5HlHl;R3OrtLkiKj^GxOAcF;rME(N^3)8|s`ONTk}v^6aC#X^h+CadW3rCANvKk)#XFLK}F20Appha(_0!2WniH zG*0X!fu139^LKA#H|_Y@X@2ESFlGcSWdhUsOsm@ZV4*-0Dc>-Bv=HrJ6y{WTgS1Wl z9@ab|R-rVxU1{Td&tuhQ?ATUlX4UE@!kU{}>YDcwtnekiW^%Vw|H?|*+_v+{q@mpP zyHdBNUJAcptwo<3??fPf$=H+BDPKR+08VKn#V~d^ZT8~V&aTK}Y~QnyZ>X-)HeVWo zthSAzQId*IFo2(@9NHfAY4*M|b0{$1}P9b*ci@Hkdxc`Yv*%qZ=LD!O+J zQHOW;Qs*JyMD5mZ(@Q>wp)su}n+*#%;``VFB}1&J2Ms`m^EH8Ppmy{8o@(|ua zYLa5 zY|DR)h~4QZKuJtsJW~7%rJlC0GHKwNN62n_@G8sK-96dh#4}2ENFR@cg8I4%*D@vU zF71iUg?N%e3~jx%ZsB_Hr8XxG>`Rc#J^Po?2nNG9ky9bt>I5a%XUgtm?lOUbZk1b3 zu@gfUMd5q?U1?c9j%F{cggLzSY!641?zet!Mbxo@qEz$hIvV#eBw6wRKWA?TnPkA} ze!Qu@y)?yn>^nIy0<@75!(!{NwsmgK)c?u+2uGj1)BBzueJc$0mEiUF0it~Ib
o7dE((Dr`W;8Fj>*ni@V zo5b`tf8M~TCw`F`&yQPae%%SP4H|naNKS`<$`&>4uQw%C521I;t}UUJFU@PKq;K}B z-OKmA%zFM1V=(5Y9Ff?pJC`dN@dV&R^}O!+>begxJuz?I6W`0X>MUBGg?$SRUG@P2D zIaH%njzbN%84UA#aTN9^R~O6v)mtbq{g9m+$#&;@V%R6iaLqYv+Wwb7x!tX&%$rVE zt9EDoXwN_3pqph(13w2d?eQ9mTIN1>#qGK^!(=20qLlv)xve&ciD5^~Yyv5!tG}Hb z%oK+EcMr%h!_p_kkBpM^QrQyeDGl@j_Y!L!3!Ao&?J#V3jJZvzVg(y&2nw4;{8u64YBk$^@rDdB{Xc{xvY86lYyH)H`^XVqjV;U z95Nt{-P+7W71=(95k9D9=0gJHlShu=WYqzkY~Ht4gdOzlSUGl!+}}Bl`kH)mT{u%w zoBzeTa5~vA7YA(aU@u0&ZR;D|?Wg|5n3th8UkqRL*^Vt$4+-}JsEEPtz#C(JH~G5U z8MrKSsB~_7;Hvbi_q;jWc(#_iZ1W;}K@sEO)gJY}KRy^s72;;!ib`4;a<)YU!Ai<6 zQlG-V!~f6%ol2%T*Ub@Qtm|in!=61lEH^zfn71@J{0{f>a!Fck&rHva3D<0AhOo1N zd8@5`@q=SOa2GPg&}na>pH6%?R@6f-*ReiN(w@`7`s@_dM_R4!z0&dn$79tl0@2lo1%mr;r9MD-7)Z+PA({}ynKxoji6GNSx`l$(!5 z#W}FJlDY5s^Q}2ZSC>tqT%Ga7ONDSjVRyfSlS3LKOXe=&LZhfw6p~_o`u-{8ksDF) zkC)BU*cool&}`=`4x1R*QRt>Tff?!!i~*f$7`r{_eZz=?g#_U;Hm=u-Q|GT1Ln>A$ zzY71Rxb!hlWXAie-Ud-;V8l`IJ$V1^Y9ozSp7cVRMN55I6ky`gvy9L9&Bhel89>d9 zMZdUup8G|muM0FNl+SQNJ-|~~9oSCaaxW!=v%TB>aOC+UmvpcCyw(g*7+soIj{T~8 z<0cM}K+6vDYCy|AnO{6&Yrl_t{>b)EWa#X^5q(W*>h*smi5B0U53qgbfU)%#{O>=j zn%+x>8d0 z3cpL)-_*Sg`Ij%PoBW`ysSD#$cHWQkk#(CiTZB8sz0@Y0k?``0RDkV5UM2LLIPhbP zixl~~S!YlcjCNpLV*k{bhhwRJ_Jg10*L$j9*Bef;zWv%2QrA*~u63xPMoQ8vU`+&} z@74;Gkz&kbqA=+X97Txz*|s`Y^7nl;eAMD2))K2O&L}s`#Z2vA#uWLxg)r|?*GfEQTi6$%D3`8j zMSRmQK|l}cvKBCEKRQ!jxzM=O@K4ai0K=kh+m`_npw>2s`iSrB#xeTw$r%UM7Q+0x z=&1=%%s}n+-n;CCLepn1BJOm|irm4W#NZpRm^(JBC_nh=oN>Khlstq=IRGB^3ckYI z!ZQ)8LjL<*RNBMNx>l(~lkmSW)TlO3KRr!$Q+A)%;XpRO!HMMjCFG6c@JT1sm-;J> z5A8I^ogJukbL4*e3S*#=v$0TX_&Evg_Z_)>+vWQ0Vj!h=HzbVx)3KUX=Ea| zpAtioZeO`f<+bs_qcuCqEuI%#nf_NXbn9%P^B@7!dN{#T76mmt6#5}FqAYZ+l``Tt zeT9MUSj!#9C`TZOtXBK7!Gw_4s)sv0KbF>ls^ zyd)d$nE;TYk*wiSqfy1tWl|a+*R;NWF;{Aquh^aGU_I{at|GdamGB><#*L4sER?R{ z7G1H+K!(~aZ1E`)+fqg94frFFHx0Zo7T1K+K@ zWK1@I-yQc6$GxC;B+S}-C;`Du9eJUg4&VVcJt0KaWWZ2V~;G&Vs*lcl-`w;>RMP!lvn;efi*|V$igh zb(P7B@jNLR+3&|S*&gu3Pn{d2G*lLLoC+uqUzv4-dXe@l%)fq`GR&$N=>DNoIogqx zkTR_qVa=7?={CRk_cBC@t`om@^_Dzk0p(rUZ9ZauodICSLLuCN7;+1uu0O~dB7)2p z4dlZGU78|RxmU96{D5WJY`0gEdZ?CwK#NPJZ2^X1ZAZ9X@!-) zm;bR3b5nIax#tV@+FfHnFnw@>4mS;%wQ^>`**$h!mwjFaLB?+8XPvJOh#U-8++bs! zzD5}-o+6npu<3Ma{i!)+3Y+1%O#(*WRbB-119<3`lW8|kEYvryw~wA7J|R!zn!&)~ znG|(##3L0Wd7b#EubSTn> zb~|)qI+O-LX?{*lA1 zmo9+U(IAj!npTdmnS$9MS2Ww8(ZOo~`Fx^gbP$K2e&XEL0|AoZ|Ff|s++d+p1>H}l zBZYeSR<)^9LeS$0N`g12;*+KYe>2R{iM`yofhi{)geK#-E_PUw3w1P zzZQG`l&tq`Y!}TtzR;nNi@#RD7Ju&wO$1?p@{rL1djM8^haY>PlYQt;ahm{-r;iWc zU}95dQWd1R3M%ntjv}l_v$%a1xv!<}$Z~HvYz{2@WN>>0M(Ze6VG;z-JTxFXTn!-b zOkN6)m*99U%!~Np9y{kf{h054b{&pD+(xfe4io3dgpH13YA6Bo!wmG0sGwcKHyk`V@39bB#k{7>q!hb-{Hebq8sY5BCnBMZV zIN2n#HROBMLrsw8bulUm?Mu-#6z?alZ2Kd+sDd3+-e;$UXg}4mghtp6|9ZA77s91v zGR<`8YJZpQk+!~;ZPZ1hxZU4gZ&faZN_|!6mMR#$Vxjv!SA;!Vj(CO-SDtg}j(gc7 zw_QhpW91Yu99p?6!t=yEJlZzRa-9u+;!MY-O>xzztKW$;Bp9Et$p2IR0of5%x zyr#|CHuml;!_<5_)hW4_ph_nM(w>N6*qW)1 zhIWBao{>=KI1+>_dtw|w^rAbowSk`K9g|D&m%F$Pl9@74HeC$lBwmK?Nw!NDZQ_$J zQtum~C@B6f8`z*5#HCX}@#&dWX2_PCJV?~1^Ll+$NHBz&Y<_hJw-;SuZ+^<5$@hv< ztIB3n5*$x-$Ok}oW%Avupud4rG(|KOZHO>MvOYsX2fkx@8y$ZKX-*)2?p^kOcL^MX z8%Ieg>8I=uosjh(Dc+otp}ot@-3Fj{L%4L{RKt#F9E zj_aa%9qE}5jfr2E4^zI$s}qzPVb9eU=e`OgMDfD^4@>7A)WjEW@z9I(-a&eiCPkzr zAPRyCC{0=b=}42_k^mwoQUnA91O%l>mEI&25$U~13rMe_2MEc&{N8(i?#!LtnLD$4 zzkBAK^I2076@>FEa$vbHbPAvRmz6I2p8bay^aQExWhXCD>E~T8)Ln>p@;38h87{!o zG9jNBB-HBd?D^TggE&R!1CzE6r29~toWZyvz?c(83F#^_q>PLuWPIP?Sufn9Kt_k} zB`Fa`kA};xDs@QmiK0!P>XUXO$bDfW&hs?jPI^;8SRuap{Y`3K_m=N7rp6=Jw;>Fo zj;3XokxS%*APH+7O)ZYWrF$Rc1ZUO30;8%tB*Dzw9H8eEhD>-w0Ylmy9Fl6 zQ|?Ipb0j^};=BzlU$5UJ?_8mn3Lf)qD&s5v*Mg~0-ZIT}j=n#Y*Ku8Qh1Bz|G!YwT$s0k@|h94;EaOTgU=x)T1L-KPmq3QCp z3q~|||1~k8u;|S6b+EJy8Xjui#ZL+{K7^1lbHBkM)fos-V3rFp^<(q;E;X^U3pwJ1 zN;0HaBmD9j4erV#-CSj-4O^*BvZlF7>B#C3;{E#Z$PBR|>NBxAE^lP)LU%dmxQBm; zjmOU6gI_0vh&iY;8=ikIT3EJAHtBE9j6H3OZ#zS3wdZMILIp3?X+0<^OH%iNns`FN z?mOTQLqV;0tqaz}{ZWMo|IV^HZmfKO1~};S?FWZ>BMX!U=O9R2k3Xk{B^Z}3=!NBt z=xc~Mx97(qRJcIY6ttjawgR$gt<3c=LU`qaF%(O{!&=`nh;;JzKT1c23FAN72X}U8 z0;1%^7$CEt(CpmScby?w)`n=vJMnL>1m~8$?|q{46b?-N6GTL+@n1><)9iO!cYnNj ztW44hnHsiTN`>usZP<$IgIc7GTc*?`e_ORrFOQ z{!3#7Nejt^fJ(%v$c;dx48w?q{eWz(=9Uz~e7v*z&NXz1Ht1JmUgIL`RMZVN8t{yq zoy@@K-@L=@PyWm9tHg_%L{ zP`j|qqtPsi`xrU6@Rf47)Lj&#A2K6nw~?-k082!2S-!3OgFEdMPI`vy_|6LEB)d`*P?J z&!nCICLFjI$?&_CG`#V}o7x>Uc$UC3*4O7WBR0F(Z`KbG8t zCKfw#_?5pYXP>JwiB@yh1dRN&qQfFlubIj=ZHsv4+a`=$Lri8Th#cyjjj^5$TN-&= zsjb}%h14+U_`eVMLs|jk<<8#xh4^g4?M*JIbAIW)T;SooJ9B;?B@u;P`W$85?z|z_ zpvZVHPN_p!Ll^BA2>WHoyXE|dmilpREGLxkiz$C%k~1DmOfof(K9UNMpg))vx{Mv> z7z+=2?NWa%JMBzCCa;~(kaw&RNbdNcEH|P8MD$G)kH!%7h&c`5#84fJ2Je37p~|z$ z2fWb})rN=i^7)zce%kpVl&lfJm#T1>CZ((qvDoe^E(V>;r6275kU87kx9`*YX9C%P z+T7My_~_^q{&KCw>>J9Tm$L50@giqXE67Ij**NlW;b$f>nY_0M)*k9^37#M^P<)W| zw-Hd7k=IxH3l!TWp-B{ojekdp!}6JT?(fa~enrH0pnIKi{Y~PhTkyYyClciHt$9nb z)4$&@U)@#xNsoqkcg+8K2;D}JOlZcQ1`ezeK|Z!NIhK>n+J4}i`VtDFOqF?ArJcm- zq{3Z%+=%)mpqk+}1~1ANkC6XSTlg%2gMFVHp79Ko_*d490!Zqx-k#TDl0^G>{bDvd>Q+@EVU&!+>1l=~0Zt%~!thxQUa zFk%Dt9G1p%QACP38*f`6GTG{hITDpAIQhKi!Qv@mTXv=woqV(GaZNwGB{d4`td9or z+1x^Ha@F9f>`#Ap9>L;r&RO;FxH75;*>va{%UjG{%x>R{PcC^tus9ot2qY(GM+DJl z5%hK!*7eCfQn~ZAA9h-!6J?3`83g0m@fF_1KV?tJ@=|jrcI%L<=68{s z__)efiI3#gPUoDAq6Ip?wtOy1Vle%@Qoh_u9d{Zhk=0-+y%N~@?Tje-wEq4K#ueIy zD57XJHhyGw-<~&D)Hc8+a$|ADeUQKZq_=+>Lyd4Q=}Y+r+#R?c1%6LLv5NPIQSxd&*`x z2Sox|hYlkTX|W$4(Id$-&(sU8ps(l-;7!QJUfB#)9k9d|B|hSQ)^&VKJAiTz?ZB*Q z@{ZNh_7D67HQ9qGFStM(HF+`TbqgY3M^$h>wHec+eEMut?N&90`IjkSa`Nm~%-^6#j}A$9}gIvgYh&ry;vzKyziN95%V z345#pw;Fp;LCqQBHM^yMa~eqOEr>S67$wO z2OPZDc42`%U=!ZfsbH zT`C@w9nU5kU8rW?qUs>nyT5eeM}GKCZ-fjVD8Q(yY)H%6lQx{T)A2vDH!w!v$!NWI zGL&!e>CtrBZk61CuMp?eXCOhQQbI znYwuqvr~`4hT%*Gr!hoT#Wmi`JkVz=rFQdC{srN_wArLsbqKu^3OW)A zYJGOK;O|+NR1FGWH`yu-ry@@<6T5<5Fbb>xuxA4L$P~<8Jh(ES9V2@BrJrqU;F})` zajgtpYWmrt**0Y)@T#@8QT9?Zn~V&RPs7ipQ$3oE;5ESJv@Oo>b@oNr6~#|;5mzjp zxGS}CV{_eX;>hai)!jJ3T<{1eU}o1gl1o0*UH3W}#Hk3ZIVd&XShZb;orXY-m~eU@ z=%)@yT>1oErVw~f*TfiLYS3!G8Q&K9bP=-fJ#65L{|E4F@$#T?$-@H-sN*vZH4VRA z7=p!#X>lIoJC8H>g!s77XEEBs5FhWmk{^^7`)TJcfkk5%9_L6lY~usAR%P+!^p=NmHL2OsRUn-*#rp2c;vP>E>`cDrwrBeMp-|FwLz%OC>%J%^~wy8L?nIL7+O-rR{=em+vS#UVC{5*^u8p6 z@WTh0(y3Eh>rs!ruyEWZ?-O+spbn)yIjG5(((Mld&qxqDPnY!rJXODaYk2eocK}->S!|8{YmWgD~QETa3*e?TU3At zZql%S83smm1L=*C?(G%JLaw_X>4I5*Z>*`)>$8sKM#$uilnYm!CT8QJPd=7exwZXc zR#R`}gri1vVr227e^byC56M0grOr?=xMs^4uQ~e$lI)kUDeE}ox1|-MCAh4MBWkZYNaxP(%kHq!x+b80XuBLOOb(#AhoLy95&<^m z7=zFd2eLIhtsfBl%=V-GZ zt($(IT%^GYfnoVnmq?$BkivfD#?-&^Bbn1u4u>7TH{_p??fK%$E=-PE>$cg#9c8pe zm;+7kJ0l0FTK0R#(MbS|%M#2NM}Vlinpu z1r0Xii$o^oCxL4HFV&)bj8yuq>;t<97SA7v%yEu~TDd<^?0sI3@7X*hc#^^6)cWSM zF7Q2`*s7Tg`~sC|h}K#uzIsH#TRdtJpE~jQriW1})Tfis4mAIlJ1*VJ%NB@>^5`t{ z=90vnW#>c{-Ykz7)4I<4${jt*-}UasAH{^lIZbmLVY1XimoqQ1kLvXscHhRqc%2A{ z^4fvR;!^1r_4>xc(W#(xYvB9w&1xSek1Kd(=S)H_urC&_33<$v3I@OKz1|^8d#Usm z4zz25wFVo2>xbL*S|%(_V*EQR=^g->5b#jJBOrJk)@3n0zkU2ykt(XLK=`M|*|TD~=d7?lawKR7l}2T2ItJ=AkJC2w7{_bx_t zxX!GdXFQL<7bO}lX9o>*(MZwr>mlxJ!)<}QRV{fn3$8@`!EwQj2wXd;3iAuAd^LgQ z>0Ud}B3-b$OP;kf@%FO|@bDcxT!HxJ)3`~5YAQpoEYX{U1YthQWYkjUT3F74TpbAlQ9;QuW?s1HDT9RY2{Wa zOYZnHz9+m)vc)4E5li!G?5LH9WqN3BGHE5zk5iEeieg5>LWqxD^kw`?yvopA` zwV=!aYFf0;> z?f>;z5`LD4zSuo^VUNV7Jgh>W4X_CCd~FW}1hTm#Pjq-ENA*g`e^zO&zq8Xc$h&$^ zTu`fhfcRGmxal6S zX+^2^r{y9z;O?>!qyG<0?e?dH7zPbvHFgw0=c^D_v#!~Z`yJK#JJXeK-ZwO%5 zVR?DA70-|{2VFI1!YZ5Y-neN!(qxo2azfOpb7d>3Z8Il5&Ypih0L`vw_(uJf5i0o8 zeF^$&`z5g4-h_hQN5QagpvAVW)c{x`*jc<^My`ArPf&n#JNa-|?#xA-ZMq83f1~pi z2+8yWs=$=8TN~sG@w4xnJCDBw>hX=vv&4ko`+JpQUpy-Xd&=#>eNOE8QHU_g(iTi>hWsv0_a|%1j7{c*)hHHYP)- zHLYabkLfzrZ?GP}*(hzZ_-fd=ryV~AQ@E$bdKscSFuMSqXbDS}&}`~V*-Hl|NU43} z&Er3Fsc;8dI|27Hpbe)Ii=zN{5Y#KTuau+X4CU)Q++Y2Ap45_wQ_0L&7WPKmJ4Q`G zWXq-=kF<$!rD22*Ws34>v-zLEpz*Gm78-VqASUFV?ecw6GYI>)EnH+l%XZ5F*jXcn z(iZe`@G+++0J|Zsj0{i2LK3frx8I*iuk)fYmtsrQX_=WOP4~jTHSYi%94GMSLce$M z=evHMK>vYXyO2jWp?qZA;6`2Gn@?~&$!~qlRftIJ4Wk&oDN&OP!wS25@q(!darti6h+$P)2 zW$FLEYzbEeF1{NwD*kN31k0`jpU|@X$GHaeMQ7hweu}y$-4nW#@-!EAS%lS^C7)4FcP=!2zgJj7=Csb zjsTjonz|+E4azVhNx$J$L4=kkj)8pm8xSSMOSP8PA0HQwD|v$E=t_ebHdvfcU>H-w z3>7mRPrTp~bshnr0~|jKKL1|q0Jg?46B57Y*TwD;dlS7a`_%Yq` z>@oJN-BZ}f9pc6}w@+~#HtItxR;@zxg`zJuW7qq!32%g*{^Yk|$B-1Blx9lzQL|5& zE5|yd{Xe#-JjDGFXmW?o{+jbuerZy#i)0EZrQLlGH-udWWTm$4<}#=6*GCL>7WDTP z^Y`2N2Eq{few{Riop;yYfwSbBtwfjzug zAW<+B3J@kk)oBTm?TbrmDEi33>pSj+C^w7-0`? z9RsNok6u$Ia8j|jy)iug6|(%`p>ET9xuo-zZ!d%`4(ob>*DcK-MF^R{LN!$cZu zvTW~}+a9c(5k{0xUsimI#H`qSPlcV4(ZbK5vt`bD9+a{?xEs!zRM9{8-Q!%UOI=Uy z&ba|R00nB69~IkbV^@#gm0kLY3+cI_soeA!$7n{c;nRNnS^eF^|Gr7OoWiWa1D2vr z^-~_Lc>l8W1}?!NV6YjiA5L!L--y8?py}8sBc{;+xc%}cvr1lS?rd0W=v85LZwT`nF^t8<8k&V7WZ^7m`>srS>*xCj@xygiv!mJd> zf3uaM>2+p+gif3is1awrG-NmqsDeDw(+}=k`>2*{Vxpt1Os(|b!t{$TwyW^_E2IP< zs~GyO<+;_N>!wFYF)|5?5AJv~{H+x56#STQTSe{`2V>G{Qrm+Cq=3dow9y2Ka)GHF zJqH#MeesOMO9F0|aH11n;kM`H>GyA4*XI4=@>98m_GL=dJhkL^*Fj(4A=p3COe5ug z028Ws;Ub?gW3tqbv?eIu2O%W^-oSy)>jXmA!J$`OmFJ?+gn2l$nTBDht#AhE<@5E^ zF27eBJ=#@)%dy(q_+JTwhCNbLB;{idgpfSid(fILMwBsRvAREVk(os+7F3snWpA+= zsxAXNmznz@Rqh%E9R&Ui8U-kS3R#z65%(gLOVv-sT6U79HJs}1ZxUsF^G@(yR#a~+ zxI3W$w4|s{7uA;;tweH06%R{keck~!v5;ZtMoo(BF*`K#u0LQ1)6b~Lo!7#R78wgY zVp0mgunR0jpO^xtiPxn4@6QwvWB3!@N9`vz(HGt=vc%Bo72v?5B1byNewZV|^|RKW z%VH4^(k(i2AC#7(J-7(w#he|A`GXOTvbK@CfY56FY+VtKddZCZUe0O72rH?!{d$?;{0L1H-|uru8Jtv?IoMS{*F`w_Jzvan1htU3% zEf4g_qoo&UKr9Upy-=o&PIJ%3!#p!72;`AmniWe+3q^@_IYl=E-B5dvVJ+&Eg;6$} z*z#!c+D#Sa5)UAle_Ts0P4Euw1!G{L_8s8t*%7B_>Jg^1(9^J+j3z_M@}g@#vhcgk z4Y7{y7^X#RAZS*$^2qsl8$*keAprDAe#G=0BnuMZi52NP#bQErqEVov_I1>P!;4Sh zckf6To)SX@V-Q*av2tY6Z?yU9!^vYjeaN|YpACcPbe|ZfmrsWJhECU3d*qD_9Est{ zK-JQnh!&cBnP~AF`M;}k1WB#uB|i^>3LnI|PuU(5+gHNKNEy%NgDzO)Z^T(^|6^xp z`hWiOe;3QvzW;*Jc8=%-9k64k`EeH>#yZIhW3ULAKBFE@ze=M@i2L>UbJ zJy@PBGt2VEn`0j{=`9sXdDi{K`$wuQEY8H|Rf%7dx9;m+`TtRQ`OD@okWgi7jWK{9 zQ1sT#%z1kmbh%^{fIF~DDL#q7vQD1ZP&&%4Ip@SE4LBDwX=D3@?9d(#LD68nJo#@G z^|1pTk62nQuH1kI3H?7OJ?{l)CW+AtZ14|mXhmO-)=zMp%I`-lt=lt9Foi_g?WF;$ zCLZs?3uusrR9pX#I?^YtUx_`At@dqmYY9c5cO|r4tMZWM~3q+6Xj2D&mWLoTK*M% z>;QL>{?i{-=yI}o0d^VR0 zJt+42n#lV>#N&2YU6ay^n~+$;2MmEUxxYoj^SOHiunudrG^;SZz>xjo*5iD=WWl~< zX&FM2govGnFNIRvldOr#*#uF1C^XA^7R~)vpe2Xy87@H<`<sQ1ET<10LiAy;U-_uFB-6_#B8fEOSXrN%!%~}27 zb!wFzso?-7g&(@p6ss&aPF%1jGU@t_*q>9Es{ZU;i{K$}=KB=WII`+W{M(Hw8`z|$(6bwh(7Yd0_ zcyfVyMh4#^`H4>XljUR+7;;Z+GC&X`_6=O91sG3|{9C9RH#a^R&G*L}W z?Wwg5Z5yti?jeWNY5W6uv$P~99;vd3!8g;NHb?H79ses~n~~8*26TM5ocW0_*xDb- zZhc6}PxG;A!Df3~Yy}O}SN%ne$fgkz4MveVX*bi{vgy&wE)LDZmzYJ?ks?_zE$ZL-nt-zKg2VZF_F!*`dO3=yU<%9F^&P>tHwv=Lsh!pOO%ksMl@zNcK8eo>* zyYE?BM*{EL?^8g<17ky-`|W{ITaq$s1d}Wy_X(WCJBV_V)jR$Q5MXLp2!h=~7QVSE zEJk8_6B~y6SFUsCy;u|FycF-hPGpJ{Vc>Jg@)SLfj5SU?*xO&y{q&Nn9wPhxw|
  • RFScp5r$>pV?ElbmOwPx))NQD~djk{rRueRjRZ?GN%DK<+%w&F4i1lmUbRk zOzt%R@NNJjdg26lx1dlq= z4L!lMk7bus!T>2UFHHq>_meNhQw;N;DG%cq)kzlq={Uhuk0=JYlIBKuF0J3$kr+S| z&4g&7Kdebum|y0-6l78|N3JNq=})P@1XR-nkG-`lLm}D8{3Z1sUKCyyJgT(+$oicc zLD68@#H#>&rI%my2wkFBm4PS4C*ayDP{>uKp1KGVk-mWsJ3D{x7lr|fvlO$UcKE~WHQL1g83mKAs zoPh*LuIK|tv%P@z{H%Wio5){Z@89rROR|D1l`4bEQV;^YhK=diWapB>(gFDKsY!&8M%BvF_H+wZkulu6dfks8s)gk>a6% z^Pb!Jx6V#JyC#dQGa{dbf{j7%ffI9$Y6yJyc@KH%%G_YL)Q`b(aCICA(Hr>L>z@q< zp56RszkV0OV(X~U8i+aMXK`2zPlb~-)`!W$dp>`f7u5hFyaX&7Sea~yqUak>y|6H! z+W8A#tGKbZEvUI=ALgDut4O)drGI1TW2c#vd2QT=i{Z-0OvZOGiPEysAp$yH(j z;I`wT)h~=R5ySwb`u(4y+IPom4==Y_OoH+!oz5%!dp@?pyOo<&Nf&;i1Fo;xsdGbC z;d@qildAnE*Jx(bzjy$Ow~$bDud?Ou+b;R{x+jsPhcMy?c?zEjOZa`aa}A^&2(z(r zS3H+v&ijb7e?6t8$A?5(zoloo~SCUo~g;*U|qD{Sk zxNjK~m%r@w?hCBF?|vNtSqKQR(bzV%S9a zKz89`nifGy%Ag)i3pH`RfZK};&Jbs*PXF9ztBYWXY|hEKKZD=awmw`B1F)TfWIiQn zN|2AN(fddmG&1ZedI@voTT?g5f=t-Ym61=G;h%{;4)v0He3L{+Z`PepM$ixv-uJz= z>m74y-j8Ii;Ed-6T4I0~{n(omgv0p|v21r>B>(W~^w3xRvIhWz;L{6WjfkpV&1#-w z?Qvz1%G3`S6u{N-6#8$7lb+Snuf!tT4}9hm-_)0-w4lN;MAs2POf(g`__ALD=F^cm zHXdvXOoEqv+d`~&=$r6>i(yNYuQmh?t0qA+vbx8M6ReK;HR)dmw6 z6aP28A^6kJLRV+P2Dd6~t~sPk0Evd%%z>Ywt@;^{v#~Jh9X+fetm9qcs}q+p-2H~1 z-k3mYPmQ}ljDN#V=-{sgast|h@Y|Zd;viDchU3%Q z{I%T={R9f-sw%|-duxGzWJfoTk+U-x98^bDB(l+8?XUaO2m#ZHz=?L_xKx-$wB6R}uXWG5NvT?l7ek!DB09Adxc zxXi!81oEJsABSSxP7|sIsbB%|BQ!TspP0J)kB*4}XA7fSKlEWybt$^$&{GM>*b8Km z0cP=8GZuCEt-u3y){T?oAO8KR5eC0(6FfxekfW88!Mw~Q5V0K-jOPjLatZW0?5y4Q zRXD9rU#aUN0Vh_zQhOCt5s2_2mPMO{;is6UrqH`@$<;|?D=4gGcJb{C5<^RBo4ZQo zSO|X~gTeXPEL+0;hJRq<`k;_99V&4#;=YwXNU9 z(oagHiPWG~=LK&`3R|h>rdAknQD56F4u->@|D(mAu{KG0?EvohB`QBW&z+X@mH;Pp zf?Q<#zzrb3ulw5(Sv?bXYOi+UXIt=vAOCScm)bCi<5stpf5D!Dy3z3=L7L z4RelNzYS}}^C{lV8@04198D}!+SbJewu-x2**Kxfe^N73XORDJ(+xzO1xuIzi;TM3 z%;26Q?2aM<;q7OjGk`klz#F;=zv7YWav=XCNlz;;Tl@O&PGm_%tG6gxcy9ITUC!N{ zgvk1bZ;{bIVkn}TnMZ6?oRz@&+WT`Kz^j|H(ihn`&qtaeCy7>f5O)F;H>!Z@ALm+@ z!(LCVp52QHnnM!PNcs;h`n}frO<7NW|XOY@y7WVNZn`$ zl+~Ja%uthst|&u4!R?C>(Tk8ppuqW1Ssg8_kR+iG)Rq&lx(w>W=!L+>D_{cyzks6v ztmPr*$@5|R2T^?_>-&m2Qy&tC@spb>xO4h%6@)M5C;J*+sho4i{7O!_EC~qXmFJMA zm*M`pDw)O=(2*)|+q6Q3ftYXp(9qO4fc>cp2f%#T_=}I)VQ~E2yPA}oCUH<`%&-8L zg=oa zC0zkN+InZty#3%!@Q5e3k4K7f=X)$~rRE*O?@{x{#`%)y5Xdnya%3S~vUhQ~QTD|n zje!@$9R?Averqtpw7=6oJgi(oRz17y?>{H0(o;Dy^H!)3c9|p}L1|i}EK?j(35*5t z5LK2lUf|V=u=&HYyfa3zh3zVn>C^RMBtF_`S*sD%v9*f5>OeY~8MFd&4>VLKb0r(c_oQzuW5~ z1w-vReajqP1 zXay{%F<)(V#9zAE&6sGYZp()_m3a$%OCRsQ{4^!m@MS)IH@f63g1F}%8PVL{)@b_$ ziPgdG$v7-d`&HhVvt6gA3%1|nBMh$lXPCe-SwNA$oMINV}M)sxjVOYoB@#(cj^e&U!wMm`j z>DP>+D-Og@lMNmXjc;{boZr-pKM`AiX)uaD-BFyim-cDA>}EioeC?6F*0|<_)2q0V zx~JfouxIpaQ3dxRK8j1hHVEIuX%h~<-}Rzr={k2)vXm;ZlM^zQsAc?)k+!&Nl0*k1VBlnJF2}gN3If_2 z&lLuTmcGPaq4=vG9g}!-`YUmr(LR=y2jT|cPyY=^sk~9=z}TS14Za-wo#*3Ry=|nL z)90=P;E0FND}1e5yZDm*bvo>OMR2<(Z_E>!#)PW}w-Jjt0Q%_0_DbmXqjgynSJ_Or zpZD=BICb3ky#ecu-yej^<=swo`~$g00#o{KGj9UcHVPMG2bUdK>RoYKVt%2(--DU_ zA3Y{AR%!W`Ws<(sTgfgN;Z9i_}IS|2j~zG#2w z_(@M8D^{!43t>TGKKr$Mm-aEy?(6!;NQAF0r$Nv-9BiU>1y2m%Yia+*Y4}DQbwpw>k=R7zFv&;g;_`Z) z`674D4)R3GDIR!Nb&(g-E^s*bT;rvGRabtpOq$NS$C+D`1&(i_2Q~V-_EXK4B7|9N z5ox7<1P7KJ-mE^l>w57sJmBL_yIqP*y2s+2F49U2i&dabA3Ke_n5@_zyO)Ne2t$6P zgPVBb4mx@&Tl0VbV@x|V0+3Tnv%D=Igy)I3T6nBy-54)YHRt;6G<^UVl3bONRRG$? z_3D8wUt$#m%>S%#KQgWPj;yB0&V$vA?tN>7&W_^2#FhRmb@@QKq-cBVF>}&6El+$+ zKMYB64mtAAE@;LAXQ%PS<)~%V05uP9ijNP22_B8zb2|yR*ujT?Dr3a2Ezjb1x08$9 z_y#HafX8X@`K>AY^2y9${~Wppdcv0^HjgGVKAjS|EcmPk!Iom0mX2;<4g*{LDM#Y} zunD}*5v>1YH7QW1vDQd% zy@Lh5ZT+A#=hcadMd=0t3EWo+Pp%21Cerk3FfP33C;sEnZ8K581v zv~ZaMS3ghX70eaoLAuxLVO%r5VjjtgKR$Uw69mfE!f<0 zGq|7RiP`-V=}!IjkgdIxs1m-rmG-G>6GH4bI3S#kx0>^K&cfr zlOM=)q8-4BE2q?7ofJ?yJlV}lcuVx8t|?mxZy2sYcAHj}>t%PR0)4wS23H-%w*c7| zvM$Tl4KPW}vEaBG3|3w=I4R$#Oi0I33oDe@vOG;LOgWZqqPR1kI^Q7zKNq|{_`>xH z$U7$^YsXgf?H|g1_568$`-{}7Zbgv_0bD;Dtt5s%fyvCU6586zNL-Lt{#+vCm*L|K zG6ygua2#ruX87W9wvxz&ZVLKlYuHieo=xt_C zrl;)zQ+1q0zB^YmvAe*uaw!tKVh5wx9)`vJ;k>xrs35~1usMu$6XkVs&sAK_eY^w~K`TVt`HAM(`sjBH*!t zQ@*y#P%jR`((HbfnIf4r`;EVA_WRoPNtV^E#;;abJR3Q$aJGN@^RLF6$d)9DC3v>4 zr*yUK+vVny3Hn0ss}zS>oGn>X8_`V^_ik*334R;V*0X~It|{JnWoOP5d9l6_-L80!Z5a91DV70@MR| zV9>nm6oYwlw?sTxIu_qYe`mSa$ipo_G*eB07JZ3kx1{Wg%*M{alXDS>%d5db0rb|4 zoq797?yl%cF+C{0 zFo01<_gN}qzfS!?uj9P&h7H}xC3<#I)Lb6EflN1>xymJ;nBR-dzT-RDVY?r?bn{)( zk{%Uy(}(Kp`Q=yB^_42mTzk<9mWLqs^2%g28{YEQ{IwhVsuNVtjiwe5kiF7zGt7y8k-8b8a4HIC9O9(%8dVGKlYR1D8t!?zD|^OXlO5JLFSZQ< zTwXl7A7R=P{F7i@I>4;l3f&?O+?c!TDhuI4Ok2Kn&4b?d|MY2rYyt-bA?I=7K#>Tv zaKrKN7K=sTvYq5tnj}BYy-xXamnF#&wb@$|$}>XoEv#<;ER#07zDDg>=b@!PU&Pr4 z8eeV4bjg)`@GztG^M7?W47Yblh1v zA<0UVS4Dt z5C~VPlI;ZH$JrYK~ac^xpDw=_2$!jtk~tQF++}I_<9XS49JWyX*pFw ztpnE)I5d{fkjtTIXcm>|qa$!sz@-;xXRR~$*g(nK?Nha+Vg)qVSNyUn$y~t!FXDXE z#a0^1drSs*eGn0k*WP+EZ&nA1+PE{o$HWZvGy%i14o zP3f!}gaDKSJ)w2wy8j#TJ6!lTE{&rdX#7-<-b5XJ>Qp0jnmt*L8e>g?H^o?57k`XqEG9E9cA7(q7!RY-NRTf16D!OVd~MF4rAV(W5ESuPhv? zDE&o(ijWp$A>ZbHTF@TI<|$q9r5S3gDK7sIoYU1MSUDL2S?_+l`T$&pb0CfULQUAz z)^~vkTk6K*wR809uc!}2%rJo6ZWvla{&Z>rS^i@3P^)||$^=uR0-?eK^}Sz1j+!L+o#s1Y zQl6o@ugMR;4`860I9s%zjqrJ?0*=(~>~2jY-zxr3O!JFuv!;uqZ;gdlxKX)hoF(-T13u2 zeYqWDH{L8c4OENWEh$3IAIxGD2eTeN&IP*puSV1|7_S0~BgL-rKYz~rVuXkgPQjXC zw^|e0Q-QDCHeZTPT4216FAHSL*J4EVg7Sa~R|l%= zA83Krbk#4zyqEbn@~|ktvO>aiMt~rL`WgtE?dLtYqA|wyhKHQ^}goOc4@z?#?wFA z=lPAY0uP&?r*oetn_DG_5gqb|jAsG)%sF2_7tWNMJ(E(OcNxKw&*vkgRXQqe)uFrn zt_JS{#{ocu{}+H3=2xYiiyv9l6kG9*Q&*$BHnR@;buE2FvBLjuAn|_>ook_^=f0@_ z1PA`tFew?e|K0P$q#mbU2neK6tCa;Ylz_7zHdTVynLWQxMQkZ0q;s=oE|!GO6LSt% zKqAl}-n^_S*X71rxOZ>v+pp)QD7bqY8SLlsW0zY=Rk8^XW&qD(`Y*nN)ZTrec#U>q z_f;O)G}xj^_UW{jApwJ9K*IW7CZi|3jTC&c0*^Xdj%AG0tyE8iN%`ejCo5avvt%8z7Vq7|32|8zpkvAxMWN zP5=f`9jZLa51Cpn;dIOivK`C1lD3C0W)4M=QIN4^)B0su22g*cN8%GsBM_z2kn6a@ zwGc8=`EZO3gj#CkMprS%dBEQD~9L9EW_#Yyz*^!j@zBoie=x!z7 zwCtESK}0?fz-kiLjVka!=FPjvdR=+B5M<5{W>N%0-xo;RgSeMpW&Bva(tuv5Y0Lh3 zu~%gzVV%=gcTKNGit{k-n+3-yb~K5|h+`~;pfy|lUM`~UFI z3Iv4r>WZc?Q#2V=l_RZzi74Tc6+Y7}sYLM=6=wV}&JDFaN)Typg#}3u< z`${Z72|X-WXr$A;3(gVz_ZBsXfr)rvX^s+-PaI1~SR=F!n0mqced?Do;Shi6gn7;$`|3~mPdUwI$ z*O?0k(_f1n{(Kcm*6ne9^%~=CNe>+A-LE3hx&I4&uZ+CLb5&IxjtZDgW^ui3#7BZy zOx|8imr+3MEi1`cZxgn28A(X)YcgMiUb4{=k^5MXhxC5>0X=l5G((cU4uW|TXy=&F z{~pHgLro(^fOE6@Lpyax1*0+|K?2n84i46PTP8yif+Yz31Of$`lQ;hZsW)F=3lWwo zb5dS7M`nAFTK8mNAc=$p>F7Bv$_L0PNEiEdx%7m^z1$nVJV9M>vOD_%uca4Bss&!M zsO4$6W8KLa-ZQb)hv%jc*0aZ4EqlJ#dDKR8cjcPU;Y@&z4_w`Fvrq7*CM9Ae&&}W_ zyptq!(&${y!h$jEh2e59snZafE@$e;^zYOlvYpBc&IHu|P<573QGMYa9=b~!q$EWN z5dmosr9~P+7?AFgkQy3k1Ox@?ZjjC)L_q28L6Bw`dVq;@`QLTF+;6kjI&)^9v(~J$ z&wk(M`B@^HP2N0wTlF&G^7hxtxvf8>H-644Rfhg&C9GBq%v}Ay#Jew{6=&&n=o<#>I>LSHEhmiy6@( zRjDg=J3+V{c!xtsL=kGUJ8tyyWG20)u9E5;19QLhuJUsX=}@e~b#)Ysr8H=%nanMj z#2MaP{f5C1e4ihalZ0)4OJ`^dmM430!(Llv30+|!!}A_{u!?_^FLhwlctNypUW9>< zPEh5CQy|_1>aDG@a=&49A0@#?0P8Aw(3|$U!H?Fo4kqkeQY!)Z`Jyi`Pd<@cmR+8Z zM+xr0eycRxJ!J08IjaOp)>=8akFE)}pOe`by$#Aez$seIHTt^9<4=WL3*#?ArL0fm za?~PS*)U(5Zo5y2=__wy_TqloB~ORj8}1v_yQ7m~imiSZMi-U}bKSkG{jnNMDo503 zR=^wi)m^2;gTPRNl~~{0bt0U{8ZArMI=pH#>|0$Filj~g8x7fCe0U>2gX?`i4Sy;> zpUd$;$>sQ8{J$fU-jWB$uMB?Qs%iWC^^jW>r7O;1uJ7fbceHPgP^d3w=FM5IN~#eN zd$>8G$?lvqU~%D4HekV^9=CELKId}d*7b{@r@bE?s@*NcQ z(zl2um;=oW$4XL?#$%SUOhH-vJRW z;+4eIx2C-$&dtVPgRW%VD zumX2M-rOzL@Z`8)311N|C*jfEuc3H9%6J#gUR$#L>O zLVxv_eG(3@4Cb$})+-}i9xfo7e! zZWFVT2s*Z$GLidRzs!jc=!CKWaWAtb#E_Q|nJm9LOV;7(Vbq4w(Hk;D;o7@QtaOve zUul$=p@9!=iV&-eg{@LV{EN7jtb8py8a=a%D!B5rQFbZHvM=g3i$fqK1;iA@&Knh_ zP-;eHl0tmj44=~s13;WBXGL0`d0R`K-NQK|F93~0S#N&1I0YGoeGr$vpb>|*u48ov z4P{`xF@_1g>=kUgBHhS5YUdrv^#t|~V^ruL{${GWf!}Ow8nNI z@>qYulIGXK-baa!gfhzy)LnkLHn**hgqs=TAb%)_g1$?Mq{uhl9J#DID}AYR-!BS+ z7udMKOOqENg*=MH<%Wo~q<`a;*1L-IZasB?BG|U97>cB@T{6Zdt-*7?4mB6=OhJ}) zTpLvI1^94yE$>m^yds?lS`xYW9c{6#FM=gUbn1u(Ze~$FnW|SubWATgU0qmWgisYn zpE@`-GynE9Sz99|;lPdTOMS=)Yf#On9(q@`WynGZF!#=R|Ee`UhBQ5nky_Z-EO?OG$y*1sK5b_LN9(g>CvVXu%Md)xS`SImi(OHDa}=p~@!}=c=FiSfRdo zX>~5Ew({RU^4iOnV80Khf1(-KYgJQ7y9ghCpbJun6Lm=Z6$S@AI31ebK1+&y&Y+;% zp~v_0bD$+he3Oa@$8E#ReIoXJw)vUEB_XeI9!g$1+Z7DIgY~^2|0vdX^7BapKQjR| zLFh!?RQ#lTRd_tu-1fc(d|Q1zZTcrQ&6gv=0CMqmKO7${#OsW-xOIbQB5~myy)7${ zYrX+7Qc0P4WtmS@Hr3a98cd+%E@{ReZwaYZkw{7GK|!WXMC{w8m4Ye_Jo4XT8qJ<{ z`&qCNKAe29rt#7;TL9QP<38-dNA#gKNn?sZW}fIg(wmXq~VnV;zF)EVit%_!?;&+}@I5~cl7P)`Gb z$P0t{qVrI?h}l)l3j7t>`~I?7!PSvp*t(mc5P%8=FT_9=itMqsCHD-Mi^aC_-*p+yMh*Syc1lr{4{zw-6 zZ*(|Mnk}_BJ?OzMU+B^;xzFhA@lE)rb)Olk2)LwEO>D~=L+S1Db?z{=aw?D0gmtJ-uq zp|oItpkRZBJtHZH=7TPUT#k>5$bz5ffM?Rr>xpyb+VcDLY|XhCr`{NK2{;M6zl9 z5hJGiwtHoV0oxtnqOPxOQy22)Vsp0}GqtC75C|A*8w&$&TA$p$I z9fGSBHS=UC&Y{Z9PNz8XVd78ATL13Ht*sKg*BDWUpxTwxJK_9muJ;V+AK5zhp#1~fo}(ue?g+zkw8?(1h4JO<$713! z1Bx8pZ@$DBmF*cSsZ#8{;*)DL>TH{67D2D1sEUNjY5MP2%hl*`=uEO)h)807pW4N5 z_jxEU>RIEm7Ga2t>PnhA;UAq|hjP7oD>0}5w3Ra^fZ2w5Y?*TJn0~4sGoMA)>0(Tb z);sUdY(0kC*Jlf$ZXJNw_*2KfmSUWuAkhG}-`OF_t0lt$I}+!?w=oxARSGF3Xpjk1 zXZVl0de?@jpzClQaNR`B$2@?W0X?I5^-G-=Z=9(4&^|Vw4t=o}14gbo{e`wcVAqJ61=-^UAt4dOt>>L;_ooj^D{jbbCWkX5 zR}Vk>lQtxXWI??XLhyZ$jpE2mi(4Le`BjMi&V{B7-GxJNc%oFE={$&zewDc4@>N?j z>J?uHOdR}l^N%NP>M7ub7x9_d=qAh+%vXf)7q)0EyI;w(^5qW_v*Lx^Cd7sn3B8DU&%WSrw~f2kc~e7oO z2owhFOg_1vEbho4I8HIF$gaFwRrV0a@q87;`qpUzx&(|LZXtIGbtdh=*QPEqPI zAP?L;4c%R8{_0o-hu0NKXsROgGWRHxHF|IfgaIHp{K1+?8DdYo9JaPfSubao*t{dC ztkZ3{etV`wzH4^hoeH>07JNZ!nLKqHd3_f{v<~B+e~zWUw>ze(x;PqWe}1!Q^k`$G ziCXGB{Egz6YZpS^DyNGIxqyuhF3yyJHX(^xkpYRK+H*OZj$IU@c%NxL2eKFghX?sx z8HNyaD8=q;{w1OY<Fs1fmNiN!nK`~8Dj8RIxqLTe^1cR<`JFq#IL`dJA- zQ1%*vf2q8z*z%+{BS^rBIQ{}xQs5A@I+U(H#7$Ci11aLv92BG4J*w!sXM%!!HwQE5 zWz1&1R-j!UJ>BO!VX9YbdEZa&%b9*3P@8mA$A700A%=hh3yE~GX$Dp(tnuU`e)eoD zk32-Fye#+y!5lUBL+q?Z12cR$}`>1ZF^1YP9p`~~42Bq?Kd z?N7*6w!~Z~%3%xz+ECO8+{$t?^oPH}UiIo_&c^_tp!=w-VSf@rZf^!;LY7~;iAUPU zJh{Z_G-j(&Wiwx)`dCK(3<2@sS-gu$Mw?X%;%>hFG?li%*Tj)$inoq$4-bNbq=}Pk zbM$y)gTmVSBMU>Re1t1uP-A6inv-?Iawn-J3#IF4yJbjUI|H^lsBevgc5r14NmLmn z*#Zc0a&rm~h%t_1q*O9*IDv5inS4P>nI60wxn)J_KCdT)toIGIA`bGup>Q`oPY7mQ z$i}NGaLQeT*VzNzO(wrq)0%4Qj;#s(%C;F0FXmSBh9y3(|AaV85kAT1oDW&k-P$=c z5J{R;Y3I32#%y>Lpnp6EIFFpCd1bp$8B5t?g#;Qxk{{^o7dieyq7Tf@IjYB>?>4&E zopIMhJ=QEbN5aQyq)G-Q-S2q)w>Smkn`+{jy04fhFC&lZiSU-ej{EF<(?ZatR7XHW zx^!Mz(62;VO=j-3NTcuR^tTY!!nKC??4h|t)ZeLY8#j3W+%V2QsO4zM+#QciW!OId zNn>aDLSFH-HK84H8|wZ(51NLcCQb@32}ju1MLZ=cUQ`$ib`Dm+QTGi2&y&U(0>Nqj zgj3DVBW~nc$z&Nh%lnF7dY=)sP{ZV^pAyv%bxG9Oo{O8awS&z4V2`dBCy*(}_P|~g z`gp`oMD}f2oG>AY%#ktPFZ})wt>|asZVz&b^fK83C=)&kO19$Ez>JECD#eDMZ_9hx z=*b>wc8Ch%?>Y9pMoQP*zH{8lze+O)8<8DR1Sw#UetzVXvZ{6n8JY#_RY-G^%kDlF zaa_zXa!|;{w9%wEV;X%l?jTAY?cNQ|KKaM;s!rgnL|b1`Gpkh_Zf{2`dB`D(aAJn9 z(gBov0J6e8lHtgM!uoW#`u8H-V(?PqY5Qi_8 zPLhb<^X9yejV}%rKkVTO|6H2_P#3ZPl0PUw&r-)NQXw^Gt_3;{vD@4|2%|@c7P24^ z>6$SdJ6T{JxIOh_HiSISP@!j0hVpu5Y}4Ptmqya!{`qrz0e&y?ga4dO=cUc`tIgn>vT!;=YY)1VtlVq<<(+bJF;cxP zjqdElH%){<(DdofbHM9533p*XG%CC1&-gmgy9#IHVBT+b8*TB2FSa^0N|8UchZ~^V zJX?UDAq^~NFnRtR5%5g;hH9u8SS#=QRhx_&An8@js4KGx>MxKRf1(z{syMM#tj;o= z9<|&utBzat*EVRU8=~ z7C+XZdqClb428$ZJbZ=&o8yqbTBycUw&TA|Gw*5XC(_O;w%j8pf5Ydad1ar&&~=

    D&(JR3kNL|=xLYhEGkNItLYd1U( zIQp}4bPmDzgWcDAv0&Y8K@}Etr`{v%PEELkj5qyyKLSP3C9(z`$B-(v{sd(Yg$>~Z?k7U z+_saXV&fRY*Bn2HOi85dr2OBK?*-2iRRHNiHYq}k>(jt$_{guj9XgOBOr2K*6O=GA zFg!P3M9ONVcYt44?f%Qxy@Bmg+lC5>IJB5LjH!N&=^z?GS=p(NC32?_*_RWuCSR!7 zU0(WW^DuwWO2o0rG=6Gs=|3c;q5KNfBnqB|eU7QBo0&7=>V(SHgaA1V zyNKHjtLw%@>S8k$Ys;s&PE{tPOAmb?4-wPU5>o%kO7W*EejWW$q{>)$YH%C#_or6F zQboEuk|fIE!4V-msrZ?$9%u@Y+5PcmS&L% z2LERsM)-{Dk+&Czt{akuH zKM{WszCi@^d6XUC2x(yDgd9H;(@2E0F6xtlB&_wxH=Zd?>SRhPg@|iQ_(*JAC-!I7 zerNy77a8$1(=>AJnNK$|T#RRancx3$)%{qibHNiEI=P@o?or^H={3E%@ErQ2HiUaK z5jfeXr}mJ2%86@mFN_`Er|Tjz#tWn|&*Rl_M_c2N+D_o%a`+egAEgb&BgjT8)EiOv z4HK5$>ti~>)qXK*%wM&xNE6qdaMiOz?;01qXhe4gAktPVleDpW-?ZWZv3L}oBd{&% zFdZC(kx}g+6FqVSk=7l+53xDJy9Fm%>1Ni62Uz{2{$SS`zpUoj`Jt}*>wh0gVe=Dd z&M95~M$EMD;t@>63|d5|oTjlxa0L{+Jec-my`5Km3o#_p&N4k@gV|qR z7g$gD8uWf{=$(o&gVLI&Tz`C+hWYT{g@)1$*=mNWdLz;qs|ny8!ygo!fnNw=R-|>T>`>jPom% zp;;|v`SRF_%&mQ0kzp;XvGL@^>#90m){lVv0DNo;y5>essvq?-3aX*!Ug6npkEvUe znLqbns=GMo{OjI{Ut~Jqvr}Nudgyk2i4?2Hb z*PJ`HRQ*2Y-y<~25^nk_Yrrk`6>`2+Nu;B}p~HHQ{R3W?nO_IAxGFgS*5RCQ8RkTu zLEu2x_U!@esef1<59*FZ5L*uc457Q|rerOVWs~&T@ z#>s;2BB1=&T%%5Yjd?Owdi7$Nr&!VAN0Z6}h+=apar}tVOLqy>L~y|eyPofeh{5MI zKIM#Xn)-C8?GKyQUj)( z!}~^Ls+tK7lYDe}m$~+t^dJ-CmG;2#xO=cgz~qAd9yx2}7hv)i5x>QS^P9?M$jJ19 z?~x|-Dy*P#+<)-V#Gw^B&xuZen=`+_T8qX{(V-ckzC2QRXwk#`zG?BBRgdd&XFs_m z3AA>1VK)85ZegmPEU@MVbLh1+g7nK9Bn7*E+I@JaoMM4%siJ5a~E zv9Zo7DU>xTUs_1lZ2{{6y|Ynztw5{-to^55l{D{OL(h885#zc zxOFX76QR7K(HWe~9k1Eevr4fG!{n&ePYXhQ~)FM(r*j zE~|uk7`ub_ri-EjeyUs3eYmFK*5FM&Pe@aacNceJ1g>YFs^hM`5(LNV9l_DN<0O2W%*|%-WVDjI^fr6pK1YI2H5HIx%)HbSEC-+p-gced_^a3or`joIa>ae zG0$A1Xo3Nc8+DWFW3HNGBr^DC{t@AUKpWm5P^cOx_+aEZRv}vCFH`HeDBys=>gUbt zz+CWOepcN~)uS$(v+5)`VC?Gy^{GINIdG|Dr^eh`@5WR8WjJ`))iuYy5jbFDpuO3nFEmmAF63 zJ_LI=+Z?q)bxVCzw$w%>(t9MCTV}p#3LEAU%U_{&1`gd;7|8kj=7UX2uL?E<(EAAC z>uwHJh7t~9bzBQxu@Br!cZk95m-ArGuyT$KkXw!=c)d(-Y;%IoY!R`Wg2!D}RTTD& z56zOagqMUCph;fnSeL=8#kW-*@25H&a-c9CLO8q1NL_{!jg9KPr47X5&8NeloJ}7x z#`4aLJgpNdwlP)p9*ReTd{QX97wchAjx&IAmK&FPcefr@0ThPY=5JV!h=6E6*x`w?2eAcH-?H@*e{dD?6TbdSei1=badGZ&< zX4JiCDAlNQ^p;F#=_wPs}X_w`5j&NmW%kDpl;tmXWBYQDc#ZY=I0Oi1^Bw?RE$addJa#dSLIUpi6#lfiyO|Dj=#I1^wHJIRSvRIE;JS z6l@_LSV0^XH#Y$JH}cQE|{U-%O^Q>emS@b z5P%@S*ziNxmGKyKp6F}F*eNhSV`=|VX}2>X)PAlSy9x2r888SD#kGnTT0`UZIPu0h zVJQX=r?ShZ#Qq3$`}hl3HPK7ipChYvHk>%I-|-H0Fb_T(?7Tk1Le47H>B^R+A38x` z1?b!vbrO=Ip~PoDB?v?mzhKLL!KN8QA)=Pb!rKdYISN%zBX5?ymiG*5QStj~o`NOd?!2#|La4{X8H)nK&ynq=l_ zok$(BO5xeN6RnA5F5}TgP1WB^6Ssjm*Fe(OvBB57>OS0PcSB#)3M2>*WC_ae!+Aqx z%ULnWCdSzIM6Ti#%0DqQw-Uz3UC`Ig=wTnTq|@bNUa|2nOxIUj`tVY_^}Od<`&UB) zrXarzO`TUMKQWyZ5tn&22=8MT$OIAm7_*XGlZTAx|B^@-`Z+C?Bpft< zA5pq_L>l?fHM-)i1TvBIlB&B`71DhB`fXW*e!{Kz;+#~FpETZrP)#K7RHPW~SO1Iy zs7()8158!l=Pxf-NBp}=PpM|W?|)KSKC|CD8CzSsXVa>SS|+Yd{~htc3TZjhJ1Sme zy_A?2n$47h%BdM|c0HovbWBTS-L@leap-fR7{kr=jG8R5Q9aq|?JsoNz#q9cpSJ_j zDt+e%ftosN%P~W44;@B#Nb_X0Fsv5X$ExrG-md_x4v4)Gkz0GjtWkWeP#ihPRf=zH zt!Hq8lO~gL4P3>}gYv?@Wtx=ti97-dF9;pT;_u5pwp=o6LvsLgT6?MpNr^MOuWk1NXSzR@$nrUNrTn z{b7}!eiPDP2AvIWX*N5*0I+!;&L!-UzySw(@mt)Vl(UdUu_h*~GzuEtPYjFfsQ3{v zQt?R6ry*J-!Mdwz(*veOay>d@cDsn;DarFjr$Dn1&M}fUgyk30H@KJJK+H)s=IbQR z=CryT2SPbv#;B8i+SUMG8cmOkv z`YNQP*CKgwW$_pi$9I0!JewB+|5V+`13WD$u0rt7I91d`v~x7X<5qskx3E|m3f-C{ zMEF%WTibETSCGAn-H_O0G>h%S`_rZBFL41mN!sxE)`g0zreHDPEg-Hmy}>XNv}G#I zzYD#{s$X$;!wdPiMw5yEEnK5&i$Eht@5P@Fmfvv8NTiE=x1vR*UhW{eB7;A)qOfsI z6n5f<26=mGx{sB#G#G}dh^`h*c@v)*;>lX_1|0hqpHq6EEXg~^pwSZTW*9HBWm=EH zGmiHC_wH&`z$ZHlasb$;k!5$c}g{0^z$h4Xz0x#fy}`ac1HZ)bPbFpb z`I-E<%WUHem#?}#88kF-H-Meq`96aoijhMjyI$5xv-~QY4ObGC2Sn;CET%T+4!)_ zczn#LB8je*0_AzNVC#L6;wRIEF(~Z{FNVnAfJAKFryjY zi9@ZJ$I&_mu|Et7f>&_H&UfEI=Vr#=p}i%$+V8Hn^6nKo^*A>Q=k`@+U2QtUi0Xrz)7(@hnH31SGU{4+>h;6yT+3jvktCx zpp@$GiBwqUc{uRiOUZl}@yi?vTtridi3LmpU`po7k_DvLmd}@N#!RZq9XuqK8Hh(@ zpNERfnJ3A2)04M#nK~6)Ft^D|3vFynO-gaz#cd@X?SwH+oqS6VNdppaN`$kje*_Vw zeV@;Mj6Vdeho~e2dk#xex`Dvv%s{zdAXAPX7@<@<6;ciz(te1fJ$F~_FnE_W4ZCwB z-NdnSYS1u=hkhwiC%v*y(B<)8=+nqwP^lky6EKg&f0XQ4HfkJMZ+Tb`ssFM&^Q-um z4vZMtJfeY9Kt>0OuNuYtz7Iz}@}#DETPXVDPZt@3x_H^d#&h*-Ub*WUVW&@Z3sIyr zO8N^Et~GtN90OIr$Mkblj)Cub$acjy2XDRYZY?=WAgB zAoik!j}icIPNMhC=3^IGz+ba(3g?cgAGsNFZC#^jt@M9O%N+b!WLF)BTPH*A0z&sM z`I%U{Jrtm=&hSg!hX%%&6NRQRT9yZ!koZ?WnGc^EoLhWjOnHQvqFRFHu2S&dVw0AbqiAWmEa!!|`OT=*xCHWsCtn-F8tdw(qys4)y~P_tXck+d{>Iist3C zX-Cn0W762DYecf-onHZfcG|Ed{C>tJGO+qLb$yr7#Isk3VeKIkV7CL|=zK6S>SGd*k}3Hcm8EI&N& zoOVs))v+~6;ds5)r6Xb@-B{X^wrgsIr$H}OWr47THrz>g*nRN|Ee=@;y}^0FIVL3A zsL)3B#mhV_n(80Y!BdHXltrnu|4B`s1)T0g{z|g6KdH z;k-6QJ8Q>jGo#Zl`wIic$xq?O*dFU!3Y{CNHK&ii!D;VWFzr;GKpN6sJ5hy&lMgV; z9#a@PUVZrY=V#|x@^$>qi`WBJ2)kwq@{rIe#{IAocK6OG&SXhB_d@#f9}`7=sV2i| zpU^2xnqzg>-j;Z`*3TG6c&J~sgX7v3O2tH|n9Pr>?N+~sLN z^5%uIO)}kbg-MMN08Fy*)y`8&llYwg{q6$vxS|!kb1K5iI=k2;BVFk4Rcv{!?13Tb zyr}B*IU5>bhUT{@3wFd5Cy8&V(Jp5{m+$cp=+Zr5MJRIEvv=(Sm&GMBLEj)y=OP+LUId)Mjr7x>O z&f?R~y@N1=pIc)Gx>zY+Zno^f;1NZ*tqwi=&x>u*zo#Y?Y$@4(-kHv{dchVJk)LE} z{uUew22N)`-_2{EPL{VJUhvkR94ZvBl|?V1!K*~aGdt~atYyxK3ABQ?vt7>) zgum=HLSTw!`8raa-&L6JZ%8Y12H5<9Uf#yX znJ5ckE&{kJ6F(x2O+7`VoBqN8{CAy^{chWk_9^EMciD>P`)3zaM7QYi!CHxcA3e- zs&0X+h5gOIyKotev!VNH1*Msn<9shf?}?PJ&i&$eN`ZO&ng2`#Z=3!ddt`KA`>(@) zqtb8|^wfFF@Y+9AbLM<+XHPAAc4$^Vz5E41w}ua-^s5N!iIpOWe)WC_V-@Qf@e~B| zE<9abuoD=u)$v+4yrW8_3@)OHuC3|qU+3M(KO_96E&@s`sQ$8Aar^87*Vk)C=5`U{ zh!;(^rNVnq#M!GN99qa|uJ_;whH|kK0uJVaO*FkHF1e}auCDjn)-fGtzQ5>ap>X77 zH$1l}jj)>#Hz3rZwj{UmRm_dX@m4^1B_u^KB|vC}A9nIXwYnX>LgdSnVGFxxe+gw= zf|y2rqE*s$tj@Q9&tlesftS^;8^P)!gtf3WU?2I{TM^|VzNqLBLDj>4RWv1R34y)J z0p5w5)HvB3Z|wR~8Ttek@b(Ikg(w|Gh9pCLUK1^8M-=zuF7=B}?6>u}ZnN=xz{{_r zufa7_B5XjOx$MLorjWHw@%vqE+{R`dE)(f-@%4znRz5s8IP---4d|!acjo#K>N+T3 z20_M%$OEiU{e|Ax>@Qpsg!eXGPG1rFet79A@rj5m3& z)%SWb{B_=Am91B+!FQIs7%Z)*)Sw}Vhx4Z%Cc_hQ9RW-zO_y~_aJr&mM_W5MRZwTsX&X|k3n5?tuJm& zWQ`B(eogEj+SXYOIhntNUTZWn+yKWb0%J?PKNv$OvMJ&C;FD@G1vc>%31DP_>S8;*pmpxhyL`;;jUiW=aO4{WWO{_<-9SWTkse{qfQC&>NCF6O?= zP+4qCn%nOPGfzJ40E#nxm(_!>72E;wv38MzJg@yT_Jei(_FoYGPPJ`yfoWX&h3c6J zF(wiCh(+>&T1H~Iggk$U$OZ{;4q&4LAFi>J;n5&!-0WnIk(Zetau0&?Hc{Sg=*#|3 zAM2|vO1c!^VvGDz&uN4xRcQGJKbH1A~jYkFG6-rfOeL?_3dzQ0)etpe}<>>+#x zM@?=nOQ}O{>f>%3RhPGCVd8za4BX4r?$fHd5ONApIVuE8F|u2@rpTxTD$C@-Nq$?<>}6Sx`g1?8Pq-$7E)y<&v6AEPl5u>Vyniikjr`qe^aVj8Ch(i1eY>? z_O*;;ewq;ET;61z_`2cQ0f?v8Pd-_#H~_*}vI- zdU|04FYE0Lb-p9-Kz0Cph`ed8KaTi!ki;Si|2M8yEwJ;pRN9LmPCMfCf#411r+Kf|t+0wfS7jdap!8McGOSy^ksm&Ymaaqi!F$ZS6 z0wk8=@3@`idF-vt7PifC$c5HKR@f*LcY5`cN+c9QZiJJiav;KCPxH$gr**dULT9}J zk}g>CCydU}Ene%zd*mq~X=a$H&I=reVVd4=y(GTu1SS+bzOIL5i){?<1rof#&!KV= z9>dS8JT;^2Dq7vSMRa%&;1V?SwupG(Wn(0bueR$-vV)bh#nXbJA!OTRe*o+%#P`NZ zcXvGS+U4Q#-8M{PcBNscOl63AZWRm$+boklGXqZm*Q3UFZ_1PD^mGJ+j62gyerBFY zm9knGNQL=`&~xjLVn86`GTb<^Vjmm!Tw&L;^*9WlpEP``MJoqtKGA@8RNiQ zz=UN1s(;-9Su|bUJe+auFD6arGQ1-3xIyi)T(GP|Z(fxhuedMNLY7V3J70V?c|?5+ z2Fi){oO8b-dY$(%kjvvo+!xS1R|19iDMmSj2Sb`0yMUqC3Dvmb2c09tLsqj5g7eZc z?)xgMC77-Jn56=fc%)8EbSJu3xJ}B>^P38Ao=@^RQ#JUYlXAbV^*6K$a9X7PYiiqh z=|noTMD6w<)@$8*rnKnyehlX4MB28{O{RC$};owm`N5yXz6=m^iBYK!X$F{SN6v z9TX-D*;4|P(@8r1HPfk6n4JP&yo`_lUYCRz9Dy^_|FuYz(xrS(GXbiBxAxM9qJChc zC*%%ZH7Aa-qi$eEp56_fpMn`rfPp;v|3C1!!?w?rvMhC;ZA3l{FHIdOfPdW0+8>C z^LRJHi_I~@i0i!as9r#>lgT%j`*$gAB=@toQDfJsGpqWJy@z_^&uC``@^%&O2yB(> zE@CN~5DY=GiU{G|2Z573#h6?Z8hLl6AvJ&ZIt5(*LCY8_7(PHBDR**m{W!`Mlni2v z&7X}H=&AGgKMQ>%Ph|KI9;Jz9f!K4UsRQH|aX0n7xxhZj!611#hI?6{ug|lr9#Sd= zTc96qz2bGM3z1h->})fLtF->$g~vUIKwsA|vCIWxNqQb_yt=N7q9SBQqnUz?@A?!X z>68S74GoXs*e6&~Nceo>isUaylkF#x+e`;eDAv4;vN+j{>ACq7Q7lG$JSo-SJ5_dZ~% zAadSgEaCT)afkxN-SqQ_7dCR723}k+snzwpu`=s>M&Dc;53BRRiO6LftbcoSqCzsm^B?ZF_I&_Bb^_C@hE-CyidyRCU-LEvqw+!sE^H10C-0( z5UulH>Ge4p&Gp~aC~SOU6hxvQ-Q}Bqd*gEusFJ=Z9f+(rf2KA8Gl%F1-)*dT>%Y!!17O9P@G=FUUNIEsjrbi({01l8 z<|qst#Knst&H*CwcqJEl5*PFoTJW`^=ByF`m3H8%J#r%3BJre$kXlX1&@p|z@IB#g zhku))^wU(|P=w}#M8L&~eGNDjHml+E;{l@n(V9cTpOCNsv-pHO;I_l|01t|yayXnw zLxw85`H_B1#0b!Tu5pk$9AUZLQQJyoP*hS$UthH#q#+RONl|`bY_X0Ee3Z#K_PwD)$ z2a@U-qv%epTyhv{^vMRa-c~*D)ElW((y`sAUG7hOLcYU_!wt5;s55tYsGS2Vg7eJ${d4UXp0?&f z&hPGk){=kbwG2<5T?E3+-I+VoEkRM7h3nkih={NQTlX^v=I@h<4kYTfELffdIemrh zVb{ieKzvVv#$xZqN04~4=>SjAu2Q^dgbut@=+I3V7k;PNdt3H5zb{!PI|sm(F%--vf=IykLo-|tuc`Q||& z{=e6imu&ziiL0TLZhiE?!EamFD^KrCGVwieq7vjg_nX9wPvlYluC0oY`P}L{_0KI& zz_}|{3lKdx z6z4Lw8q=&e)8ls&f+^J3$ZoU5k_+HcgNTF=FQ}SRvR`ctgL?l?$RnRo@#u5+;)2cg z9&kvKen9Mw45irq*f*G#BqA59!oKzS{$0%d7s_it$T9J?z^U?0MS05VC)BA@CkfC2 zd%nFMt9xO>e)^2^W7#;yiBk`!bP&7gOMt6d*!uW3B3qMR#pA&N0Ld;ds+YK1txL%m z`Q?bC68dCgIt&y9_7dtT2Z~qOaYYCocku0#MtzR)G#8%5L@!vVjbi_RS7!d+*KTko zN!6q@lXJ5T@Dfi$*7@gQ7MduT+^KK&Ww4(QhyRDNR+;w|_>-J+MR}maxrcxx)w#2j zA)0y0AH_A%%KyfdBl?XW^_@A?fY|^X#Rq2W3+%| z7!{*`&pj8z4)2e5gYQhI5E}hd)TdBy<;_T7_k9a<$D74bY&%poTsS)J*MIKyb24Ypzxp1QA}x_=2nF!n|o z=@S^xuMU?__y<7gq<@tpY(rq!p!qL`W-Rop(z)h+_yyv$0rEH|dKcgdY9zka<77;y zd5EpnLtbt4)~K+=U|$YhoP|EvsX=!dmcvILm-rVlY^VlYeQd_uX#H6dAwM!E?^GDT zjQx(BJDMHJB45^5N~*cl@>5b}X*UJ93??aT;vGH#yUzY`k5n%k$=9(VJ^XEB7ZWX?v+p+hHtgfpLpR4Jn*q+rdJp>@dMUQjpAsrulg`k%| zwj^nt+XU(MltIPkLJz+co#MB%KG$BRe%BW751@t~JM+aFO62DAvWn9yfgQo8p?n<} zGx+^u%)u38Nt*2)3Hk!M(Q01iE{wl>KWyOXg1Zy8n+hJY3hzYeG@nDsv4y&=wBtLEtCX@ z%s@w?eQJK-9!QIQYo8pvg3Ox5mP+?ZApbf8cj1Ql$AYV8H}qmTviJ-OT|2<e@Y^ws&{R>GqM|c`9*J%v=aLO*7&wC>dIDCUFpH0hO*GA4R!o{xwhV7ki-AEprcc2^c*})0(`C zt9S2!SonTyhJm0Up{1|aeVgdCO5-O3&8F#zaLj2us^omP5xwC6&bj`U;}yeo2?QFN z0S|%dM$c#S5|A%`tI(O=E(#_bRtju~4L**Je?tF|sZhg{sBn;9kSFX6jt(uxy|R>) zDCZG#R6qEG+2J;g+0x;`Np`a0dzfN#bzq69c)YrVCCQXPJ5wda=ws{0e;T6hH$n2> zm_~IcNZuBPm%f8$f--7pj=b)m00tMC@{db-Vu*LA6og??%*}3H0}}Z68Y?p3YtbU5 zpH?areR3^;YYMds1WHNm{dkNjX0ht$soO7XG7&K?RWA!UgFOYCUantlcFGrQg!CNe ztVIF4Tv%_TIWKkzH6}C|l3N@Jt>m`wv9M6LF!64HTj9%REe`vnfeCymy%=`B_C~Zc zU1?K|DG=%O2k9(({dc*GBQ@tjmXIql@TT0TGO0@~+AhAFwizJ(Qt)MCFgKcW|QtxuxfO^9-<2s_@)n#%)l-pCIa zps#Z9kf@75!_(}9qI_9tvk56spk-Tz_<2!Z-%7e7Q3i+jH?0QI8~4>>iyD@zo!wY! zFK0hcyrN=&+>QUH=-Ea1>uJ8!3|pFGpI;iLZ((Us0qfy|o%;rac+&>^eZ9BAH#eYu zT;R72kLB$KIpl5cBe0( z#AFW$y|wPZ3bUL|p`-<%m!czIUX=w@o^wTZ2nzl<&Wyd3cs#+OAp8=d0wt*;|Fm8N zCTVQ*>zE{~5L-KR*a!PkL1W*%J++%-9QnkZQ~g?1r@tpg(i(om+(fVkzZk7U48^W! zv;Tdak@zvF@>JF4Y~W?ue%nx-DCJ1GI)3jR)JC9XKHmL7E21|EW4SHP=ilZb9UI(} zUlzjBbN&$NBRV>kO0^^-`&PT)^cyD|t?c4Zk+(1DgAZd~RS(*X|1JL%M|Lht0+_H9 zqnO>uxWa*$X?T3X)eGp!+rK-Gb~V@*;wG-MfsXZ|c+6_Tq*T6C%Q->?8SB+owH%pR6dz7~Q zH;FTH%pgVzytS+=)X5AV1&ZKb0i%?>l!l7`Cz|Zh-y^zh!2FMgC?e~F$iNl6sAVKk zn(h5vcH+SW0;_0i<*RQ;D&BQ={!s(cZl8EPi@Vr`r158sC}Y4vZT&zA@}yb@}}B}@sSvE`RscG5nI|~7wK!Noh1WUr>7V= z_7MCZcM3i438s~n|5?C39DYZby=w4X$ZB@%Ep`20+~!Tm{*)fb{kdrH(tC;hAdBMm z<_BEWMO&S-P&H4?Y{}Y%xPH+1R1ZmYOq4a+LO@{A3Wgk`T2DJ|%gIrtO=usCPZqT76Em6~aq-gXtt9adVq} z>c0k_V%%klP_9lQ$fQb)yUp>GQ8k_>vU3!9de0F-@L&{2p(6mX!C059-RCqr+=!a7?Js~#I-d*M8PyR zTBC5`qCZlnijRe$UfCXB+`lTH#kV-zWsG}OXOZ1)j z;DH#>B3;;#s2>R8I(y+*_F9wB7u6y?`>cyutfnOJ^$P!i=4SZ6G-@UyIzk5>h}j>P zj)l{yB~Ff2+Cu-$yP)%cfyup+wp$?A$#u zs^~a&UaS3!XgY^&Nv`Dfg)4ui>g`tg2#iPCWSHU##Fb_kbTC*^4}tiRZ+~^`h&RnC z1-IW1$;@VDu0j{omq~EwT=dnuxug$@a+Ew;1{>GI`XHVlEWIDIzJ8cM>Gu%vCeUj^SgU&MQJnA3`R7`|DaP8+QG|v{T*d#zd6>1SkeYwj~{W!R{RjMqI!JYn@Lhsj94=B z^%lD%Sun~71kkvzD5lnJatpevSnLAov^h@ch8W@r1MwVmyjJ`^U#8O*D*jx7z(!R` z3qniK=N!DV|8~IuZYaNjuL`ibW(>K5-%AjE5>Xq7bjbja;GML+oeb2(+U!NP*zY2B z8|$dmjy>cRXpc51-8A#_3~y9{Jiz*OUje+Nk)-)2K3e4{kay|QToOC*r&c$h$KjB5 zuya@=9rztn2l+m=jiD%Zd`wrNi9akxato2N-r$+_*O~G-+?+SDS!02&+dA$fg=P>Bc zBBKC*Pa)|D&rh&~7oEb-;Fd-zqe!5|it{1u)zs-=(7?;x67V?&tk(dU*HfjtM|X|> z>9jjb|KR;Rr1CLdm$|;S^uA+C|2Rhn1&sW2p_&*#uN&G8v|@HP4MEd1blLSNT>9HDYcZZ-ry@}a!?V7N(MMA$G zyD8e%2s<)R4qze|OSJH}|J`%msnlgcP+%`0gZyj+$FSjj$6;c_mX83HUKDr&a(<@+ z8G0{7aAyDBW!vmwZW*7vgY(tk#w&I;WiZ8+x}oQ{XcJxL0iUQ4r|7`{@6Bt_vuAzl%4^L z@ZR4XNgIg$y(1KO8=@YBx4~p$thHoG3AQ+Vxzv25IDQfb^H9{Sz!0r|lr@`HNO2GV zE_FW0zpnvkBUdusZkeS->O`CL4KAx=c-@JwZS3JiWwEs5T;3U7S)E_LaA0rcnLJA@ zKxzo_KH1U2*bq8CutI-(cFnWtS0EACQ^q^8o!?RSRA;GEpp6PTxAhPCCJUZXZV?4M zv9QS|3c<&rcZ`e1dF2uK9G7fqW8kab=pnne1_!m$1|>&K>>82#3yNG@G@9=b)=JT_ z%P?G(d`A}b0JvL`E_4WI4gv4s07W+RAKIYlK^UXp1nN(mg5r;uy`Oz_?$RT*)Np*2 z@j?RE>k{FNjAnWn$yj%+%h`FU>+*6-BHeZv(@uRPCg4~+U@h`QPLM|%_t@9S&f1`zPFN{+b6NH z1utRYeXEHt&}>RwUDS6%TmzKxD?I$5=HCn25a)i5Sj6L`n)zOj!HZ4)wzKJ~kEtJO;a zLUj^4-=s;Q__;?TFroP%xvTKwkF!)a)bcY1q=|nysT*GaEptsvxeStXFx#hRO;1yf zqMzr}l#Vo1W6ECL{wndm04(2)&E^N78A}~SiP|0!{hXaMio)SM8~9%TlQg&Xi{oa1&_9TvHZ?%cgv(YbdH7bYO$v4 z|4)03{c_7BwCQ-aWo3*xAEhs?0sREzN?A6IlBX6r4fg;|!Cwl&SmYF! zTE~LL?#=i2~vl@#*uzArI2YRdMGUeNyMuJ7D?&m^VPa10^ zgsiA=tn>2P?^(4#=Q-_UaAkz$+V@WYY?7kE1papmKvs)=_JT*c8n=5LLXF@HH+WwJcJHpumPuj87fke%>( zTkDCq=)K{A2W%4-^RvE(95&a=_XY1@j^KGm)_?13hl<@u+vZ#T5h z!}5vlr-}#pKYfq}BVJ2ER^QsMZ?k!=(6ErDk=(7+!Z0h~h;xiXO{@wi(18P5hY187 z&V~Dw59!}2pI7B%F!KmX^WLOd7fM^x*ss4xe9aP8hhf`35k#3pO?f)5;{6_DNx9zS zjwhg8_IHEFU^C~7MyX}uQWXSGoSeQKS4 zOTBMUwGD?{Zvlr8p=I^}>CTORn%lw`(gYxYFjMkpGFvN2jQ;l{;sX+g+OTlu-t6LD z_+WH;^^7-gNcuJS=yF}PW|21+yly8;<7C>;PVLW@0mht0T#J%7B4FvF*s2cX;s}~F zz53Y%Lq78tVr!nZ)0Jm7f2b)hxg+InKXi~zyFdRsfJ`b{sND$Kjam$;Lp~9^n_N0= zgKd@We+lDEIfDaTuDDz{g%0_qE8PH@tk{49`IfD`EI)VmH{8c>O4{abr*3IJRXKe| z0-00?tD0zrWnrg$^UlgjV(BVTGjP4%Ab=ku01Sp=0H`vD#1Bwo(I@x)-H!3Fg zgb(SGYg0EGk*>J(4nIp&a>qIc>M0xfq;?KASI+7)!dk}s(O87MA7$X}J;4Vvv3QX; zbtlpsc>o4mM&buN6a`)I#HrqvTHA5R{5%9a{!3h+Z=HiB8G>Tl?2qvioUKOoHG)d=Ih|pUw~>H1zYp zpRRzL`_;=*2G9FZS=K0o0?r#W$)8( ztnE3sx^hl;n2YD{v@7p&`J}M;66@${AGT! zP&2wAW-3k+Xii>QdMd z{rqw%Y0P~OWK+`gk7cA`987lr!ZKg6PRlxX7W3hw#2y790Yg>G3ifh#v|lTiB!tWm zyQO2$%yAW7!qNjHy#fMH&^Lp(1ecj=x=JBAdjKF%y-fIyfR$>Dxb6+ta^T=oBv!Fx zt2d_l>zqx7SB1LdmcbN45V08z#JPJ?Ayp}=cwQfGgvBS3o|Vc64Lyhjrz#B)vGT+x ze%`WG7)yv_T9Z}9)!w_oI{P;9`R*-6F&${{!t*je=-MNG%QHhLu7)+=!Q?Sfjzl?` z`Q&@!!Fx}xQBVyD*pqhD06#%7bYZ;Nnp*6MIJlwx{cpBhvxr^5(> z$*)|l4&F3=yiCgSZ8@|@L4o@y5VbkvV6>e{3i9ZvNLBQ+KjHc_|CFFsn7_qR!whI_ z8T&Wi9yd6~nUZQA(N9j?k2_POP_=V=ZnaUB< z*pdo3At#J=&j_Icxx{ZT)Jy!!%};{hcrg3b%KV?6>k-F(zeM&2l}+%&J~y4jCXNF( z!_M7SO-?k!!-t-4x&eX8dZjXM2I(w0hm(mgjX9$fvw zpI;a2Q>eP)0k)c zozlf0rC!@@nI$7Jiu6KxW+Tg`d!O16$yc6zVJ0yYDn^yPQi2*UUpMs(rT5)dN%Rj< zO~*fIWbe(}$?@PR+Zj6g{ylBuB|4A5`5_M~vwkp0d`;!~BXtbP$CdH$XK8;KI#f`q zjAHFVzQCdQ!Ru}#@FuHoNeuwG_O1R8vs4n}o5a}`Z{lWFf_uhj=Troc#s&o=F|M#< zNL{oYOJ6Vw19{S1uoKk+qzTTDSK)6|#fyon(5gr&;)^bav~9pMKJ* z+26WfV_v<+jkE}`QH!+?^&fh&Xv{z1MS}c&Uu`Tso9~49FY=HICbVfDpE`tt>xBDR z@e|wE9_==6<(-!}Tpya8B*0rASUVnK>Pkg}nnQOAqfSZy@gc5wTj3C~BO=BFG#WB( zw__3T@eo&(seBC+9>W`d0yyNAa;&KByq=yTNc%!rz!iKfkP$DghFERf`Y=?P&h&`B z4lyM{U&aIJt_bN3K#KiDKUDhuH};`fvU=>k{o~OcZH!jUld)(Ucb16X5*{ykI~`aU zA%niP`++6VOy;=_*I9a^MpHMdUo#cSK3`Gb$-|q#dzg~J<)IJ=N`4%+W%&O1D#Pjz zb|F!u=eUU48l`wot0_5SG#vl^Y)}ZU803lP;G%WV2P{!dGdeRTy~;?n((>3>4w&(8 z)*=t1eGa!)(8!%x;}Wn)^1l}lLxib0f0XG#cIZPd^fC8`oSfA0ca)8r4CjIbPG#>2 zq8zh{oYqfCYMBIyxbJX?=G?Nf(++_=}_2FBdQ;{@c&$4=J<%gDe+8fS>*g z17~{~ItwGn(8>j)^M2HF85A{eheRgm2TFrp@@C?Ykj0dp6$cq%<1ghY7)7P&FJ6Y9uRs?raN({$uwc*EeYtqB< zh1&NXA}0wuKMXQ_IZdrRCsg_C5UbE#K~vIX)HeEG=P0FV2uKyovhAiQCr`Ln-vfnq zqfNe(F-ZT7ewQ42NzHWckb?T8Gk)aw$c-vvZa6^o%6H`7@uq~sMb%|>>5m0yv5609y0p*Y`%EERSK0?#u#hU@z(6z!HFc4x(tHsO;c~FqdbKF z1;xWHg7OYU#bT{eF%*=H(-Um)*Io7M{u#!3n=z-v>Wxxg>8hp1M zI6@^b>ndz+nl8(k*(RM5)BI1bji|TUmL)Y~?7CF;=u2t%ZpdNLB^}qH!=jVM3Z~J{ z$RH%oTXDQd@-+;P&7GgzztPfOGYX07RK>ey5%M2TaM529b^t;t64JYu5ck5e7C8%o z<~kj*ieqe^unZGSl|y`&V@5qv1tk9(uF`f(=FsXY-l$}|$AE|;w*IX$-r0kxt;<{M=&X0 z%H_?Wk;fhwCIDLm&g$w8+yde`5nVKKs;s*e#&Myi%Txdl%deu9Q84LD;KF`r9JJ0iRTB3m!BA!Y0fMzo_&3wfQ{6A^fq_ za3Tf1;WeAPptN~lZBvLFbbW|yd6D&9AIHvjMWP1iB0aT-!X57L{kmz#A=3;60#7^q zkELEjZin3rHrX_^5P*T-l}1-b5`cxU-e=%UxB|3Tvay?r$)M1Nrh4buJ%!^Ys^Z1J z$HBvKw8~7}i*vQCO^OL1R!3Ux`IR7AlQc(c-!@o$9n}u^|Q>Z6P}2mf!#4~C_`)*JZV=sr=$;1^fT|P7 z8oqc-#}e8Hzp3*VZv;*gyOCMjqd$uT7M=oNPW6yJ-=kMlVLOV$zlMdGUuotO>QbUZn@VZS?$a+IiiT=xUS_!lCMkyGP{o zwQwLox>8IXTHzDLel_CvmC9i^>!jyhFe3hDnp>WWPI^A3nRu#DHTl@-+gsVhwv{+J zrM9TAo86XpU2Q7IFcWX{Ny@f=;x0oc(#RWc(fF;Pe1cx+FN z;C62!vS-1ALX@%JxMO-79W1ifm$z`G>6LGN+iY3DXqusqMyB||_HoYwLNvSY+xtHs z*=YPJ#?~?9f2HJ9=#p-Dmp+DD>n7Kkff13KFSNT}J-nkWA#W6``X)qTA+%hjDAat; z_bWIOvcdqkwORd=anx91n{{#d8xOnp4N*~k8t3({0>cl%8?{=pu|vcka9sVdf-bU| z=GsPh5^yn@Zxw&KtOv~WU_wUCDB!8lRtWb7$HdWJ{+wS{bK<$FY?3tT^S$9w&Hb_X zpQL~KEhPvXAb*BqDu)-e23#IkWqwQyATRwYzE0RhgZvHiqcOnYbc1oT1Oz7hh6OWB zBXE(~Pu77<@jp>%46o?$7rwGAVanK?e%mRdmehzyrp?bmMe%@9(hG-WX278nvn-=$ zxi`(>nLuharyB`A>Mj#E_h;90ICw#0#5z7D8JY#HN`P)H1^m-S+?6bV_EY+KyVCL55^=u}loT%Ug?s-vDOfdj;Rs%JBuoh}@Uw}dHY(}y<9%?YvBCNu>>(DID~j*DH>$ZYv{c@*Rvd5ypqf= z94b9%c$Re)vw$EFzmD$PW}W_+Nfm{wVXGt}m0;xr%#giw#;g8E}Z6T73A?T`sGh={)7t43+f10!}9iXd$oUy941wujGA{!kGSi$G<9#3ZV}qu}v!X2J$Ud(z;r2+X-J(mp;8Z4e!p|R+)Yr z|CLSH^YmF&L&LyDxnze=IS6xSTi0SMvlQA{dk(fHP;kj5u&x}`o;=;-xXImAElyKB zB~)@X58HKQW1^RXkM?jN5FN{1=6Y;gy_A%g@9XUyU-wbl0r{AcJjkaE_q&UO{?JYc zwwBp1#~RN|T8mTqrR`LDGKDYN^9#tmSSjvb%pGCD*X>%cOCkT|c+(v_$us`Mk0c0h zPzj#pFIhUv7jqeXMAwaOMhvbo}qP^UG*JF~5 zZ$=+&Z7k8GXLlmm!-;$r&+8j3>Va71x)BV=tzo5(1))p)Dd@$0^SU6M4#=G>?fC$$ z!8p{up1i;EDIc zm0HV}dC%G)(uyYVulA4cd#H90C-(Yg73fXJIJ`2ff)BTpsE4Mp{}J95HyAUmpx!+! z*&N)A+-#)}#x2P*x_NXMro>-$WS79pm}CUsf4}r#qTDqE3R zD<|O*#Wc$}yNzc>0MfQ+@1Rlr2W_UX=rd%PiOJ=`K5)(iH7Wq#S1zPGqaddLWr?cr znqY0p*kMbBKsxDL!8!AYRI3d&35OYh7OC)RlRTn=pG_@XATtm&P6935W9f9QrTizE zXtCAY(B-|`!BVRlbDV7RcMWG^X>T_vO0z@?Tw-Wfz0vVv!h6`U9Ok=Yt4HlA9t~ZE zOVa?t0S62T)CBB^%B;+H#qlj$A0%=p+dvVKXrsKnPj=vs?Ym3hfhZ5_zp8+q^sw#x zigQ&v-WM}_5uIsj9wmy#O8s5x_}y$1>#~7s-jf~I0d9Td@hJqdEgiFpyg6-G0B23@ zEpZZ(j14v)#Z>Jc4(43~+5Lj|N3ULe&EN*@6Ck=*)co*G3`o5;dV$b-*C!m6Js{mv zL0)QWVS42sA}hf?0Q?4NX`WRT3(=9MmHJIw=^E_-jg7mDVSV_P^_8BWXV%nnpJ5Es z%Mg-6d|q{}!J)31{yx0IbCCN}2i)S#acGm@nSm%MNgq$9^o;a#JX%O`4i}8~gj$z0 zEIWsTg!V6nDJ@=}xYf&G<9wJ!W4#aGwYUd20^OSXZTJX`xXBR?hEIh%U&jKAP(=KC zjQ`qQRC*ea+(aQ+>8J*yp!_BW0yXkHuH=I}r1m`0lvCn;gQ z{f4Vh*$1QyqC%~kt-E(}BZs6aYQH9I8u(|Qu1Z|P#*N5wt@;WnKo?>~*XnCY9QD80 z79Lucv~p=opWQ}f|A6!mV?)E(-3;S2$yqo2P)3=ZmxbDOKb?;5^~;dRci~*U(sJGo zJ2$H?!C`qCC@BW59JTwQ?1o`Xc0@x?vNC~@;O1@77J^6WJm2>6tcBi#@7oe#8_Y}l z>TQ@y4=5w_spH~A8XS2M={Zhlne@zWaEUfDmgoqwG#>}`RgPhY0-ey+&6S~DES=H5S>Mn$=L}Gh}O~gLr^^}WF5%*r_+;Es18OAm{!Ezit-qiH&}7_MY&<>WubX=ftoO`3nW)>4 zH0($1%C00#w0(32Z_~8Q@Ka!Z(S4Sk??j$3G2>8C$Kk{&;7Ia;al>xHTxru2FBXxJ zXZoRl9t4JKL=ffkf;*QQ^Sk#A0yfaC+?QE0TvNJuqiy2?=`j)B$a`9fX}k#o(5IGk z7yZB#rfclcX(#}omN;n$yqV_m3AF>^=lhv`??3uJ9gsNnKlZDfjK)a>Dv5P(O>z`e zx{pc%_mN=Tbg2|0z?Fjg!zXo*UW1ft$7F_0F*a#UQjT?p6x&dw|1p_z&pNpNtI@5E z&2i=<1dQ7Oy%b*L8#~Q}Gl&y_hf|A>ul=9E1|Cv3+z_-(LwY?P)lb;CRE*3d5-F!w zN||7m%86V>F4i={fUNUT}Z4?b6j#& z5;SKNzctdEE<3LoOxVhA!_z*Oca_7v|IZ_u;7IS#&SbF(e(?B3UChO^L#mSQME%W; zrX%8*br3LUnqcmT-~Q7NexuSMaK~8{9@8ZtSo;=k>VRU8q_^z@6Z+dV!r^7fe*-Iy zkv>j@!!`au@1c8reHa*l@q|}{UZ1z4XITN9%)VXEXhFAGKRNL5M?M{FK0aGBi)mVN z_o`LXH0c*Jv(NjEgkj;}4^oiJUgr!d^4?nRiNf~%I%)QLOcbuf?l+H5m)ZHWZ9&%W z5$<9)tp0mKqBVVqjq53^YVnYx(&`9JFmsU-~wR7_sY{b3{Z8jk zJ_G2WML>$z<9Q`ONMYjh=ICCylutroWsaMCxQ2<6ZT5`?;7H;U(hRkPUY257pUW)& zPHszeEYNyex8FfdBFAvT(F1|rxk_rt#7lIqbSM8_Vu#_si5k8pp?5TLFTHP4q~8?~ z(Iz9GhW=oYblk^tc^<0PAb7m4g1%X7B#tBOk>ED$!2g9hA_XvNblTD(Hc>XIashn( z!Q_NgY9?}=-{jU+ebL~~)vPGu6DSPr87}M}c`AW?yJxo7es&;#REjw<(Rjz~BSrOG zy`$H?gdszI&}pvv<1GF(;0>8M!82?&yjLl`w$YeYKwnS$_r0)vGitP!K4>-!$x0T^ zyp35X%aNz%x<^<=*b}r{-oDLfSXAL6(pz`vnN0DM{w}xpl~U0B4$WcA!gzb%Z59kP zn{ZvvaLjC&V3tI6T1~{HZ6@UGiEpZCA6U&%{ytfJGq}$A#QvL&E?uT_EicO#e;v12 zX_n>Rfe##CP6LYaOOn!EBtA7SV&65@%YBK=8Y#{GTwR;1Hwa&m3g>jMm;aB!ejd8^ z08RH)7`oXK_@&@S2~v=4X{5*%(Cph~rKcq>e){~MNjN*kr&n=Ghpu#q!vW)!8KzCg zHXU8!M;9O~ov5i}!4QA-NH+Hlyk|bU)-KeSdOI1~T%@+&%wlrQWJ3Bf%L9U4h|at1 z^ZQ^0ac&Zw9E-LoDB36!19_;+8y5J`VBa5@M9;FgABtkt^}hr{&Q~)gLQyPmhfOc$ z0+#t;Ad|&kK4!QI%9LpGwSBk)YIEQ3`UH-lEf2B-LqWda6lVI@tFM@Uy?!`*Vq``{?Iv#ABu4DUv^4z{S9EUC}u>lLLii6Y<<9ZuH^yy zfLIt>DSIO<2Q(g==8tUb@6gdC^k&iUGMxi{EI;C7jN7PNZN6wY_|oAkMW2U^i(7GJ zdZ=bIffL_wgRP%GX1nq~aFRGqU|WO*(9$10!-PwmYDAYfNF?cyiE;>^DX#wu(*$Du zBs?&*K1ZEJ@UbB!&N6DRfmatb4!QW+w^PX&_B)K!7LYsueuurP@RA)Bj@zh%|M$Za z92^xpwZb@~$I7w*@(;DGN+*zT!ZZ6?-XDMHZ2~)m29E!09@oK3OQi>RTyl@)Zt1`P z*~25w*Uui&y}88=;&#`LK@9J|tKcyLLvC^USt#$zGbs-SeUsPvF55H*eHAI!WqW{A zzn?9NZ2l`a&R?s30@1h+MR8V7Wt5d%!=f-jg+n~2u8kc1FK8$S$G^(NS-lQ*6I484z~9H47jWY{Ot zgABQ0XU1sD`*TBAS+{5k&hmx^3*n6aVbPCJs7KL47k9W`&DVdtse-{Usv$bs1q;J8 z;gFJq{}k7;pn1ZxGc5ca)Z(;;r(Ti(Pg058mB5}38Ocy~|BQSz4fax!k|0`Z(6$6r z#pIXgmRJgViThlWfU|7lXn~g)hmN!QOGv$2Geikgk^oLpQHj_7LYqKnXQ@+H`hz{T zMbFM6WCue;{Tl%pH3ZjNfbl`bk0Q3PNHz(E?pE(!Vr!~&?)F7$$g z3K6S~a=}i}x!QTn2>Yu_)3dyPW?8p83r1i@Ey!keeSymVHF`4~Y(bM$9(6)gdmw}s zFBaP^IAHNf1||Qqcyo=un!>W&?5qT2?M?4UzyqKIDCx=w78=L;rbEdt?-uE?BzN|3 zT>TmLgB9>pTyG|j=-94j3~3>vdQ0$0qXtmt9wo1_+_=E`e-v!@ONq5yBK_N%_9SY= zgTthxy^9!P)3nK?iOm{N$$RXODcr1W#N#-FW=S9*$1@ zSZfT;RhE{s&D@QYE10DqX=kTFj8skN~y=(>%k0QangcD-1{ z#To?d;+qd-L6a$Tt8*|~vFtVU1v$){Y?xxD-7e;4?9dUe4zMEh^cw zgi`m{w@1_*{?Z=)B5NCkNcKWt@rB=q3~99JO*0qA*|4e=4jH!BoDTQ7387Idc7*a! zJ4725<;O21o6!`{M?IbY&?hyBZ|h{$_@-_d$onNGugxq-Fm@Hn@+f`deH1eab~V}% zdowo8?K-&fLIGO&ajV@JyUA)ICw5QCdK!fe_GvHB^a{n4Y)0LCZ7ub9j<+yNG_>@; zE*^{@@NptOjz(~%V~KlchRr!{uz7(W2#@y_NNE;xkFSn>S6pWZ40iqrMZ%4lo+^9B z^)bBGQSwnnO%CP@q6%4}45Lmv9(m$_bO;Jt-))hnW>PxinD_TB&d*v9)w8IcDQm>2 zqED|KmbiO!=X?dwGvx-V(VGB{vcAI~MHER?&3WvQD*mjZEgUgT?do3y6+qllPKf;J zW?aisy3o`FJ}1~4>3-K9$CFHBCQZzcCyEq}q-3Mq--x79-JM3`#tsUr44VCSYq_W%z3N+s%_pZ^}leaPAVDrOTgb1QG>1MC~5%nFRHLP9ioODj{(jHg^a*ac0aY zEF)_0yj-GO4Vv*$hBuil>dT1_ud{uD@&cZrJ!n4${k2RiPNy-trQ?KN-`C7xpm95F zQ`Ro>wH-d;ujYHoMJ@KJyt34f&n-L_`qx#q6)8`g{D+L`?)MT&>;9q#$%+MJK<0#k z+ewQ6bj#d$)RyM4eQQ4wTECxuEa7kJT8)%F zBs6X$J8}WXW@feJ!C`8(!WW*+F~&JUg?8+w!kJYV;Llg7I`PriyJPkb z1%*?T;!q1D;NYkd;4G+-MET%iCK^<=-r@TDmme08q2e34*jI5nmtMWtI4`&tPsxro z|E$Whqn{vDvkPi4r0bhvw|m&7>4f|X{sW*6^781Pg-hNQTX>cA;@aBS?uIHhN63%EmRRSF!LFqSCudWk-0$=qF zLd26d?|D=MqH6l>h;Oy?zU@l|*i?#O;E2nYfa zQX(;cNJ>f93?U#Y{gsyPmhK^x5T&GRKstvS2Byx#@4ue&cCU5zo4xir=f3ajbH)9j zWk{f%5^R@|JGMiwp@?L9w2m9NYT`#je?)nrs&e^f@xvP(o%njAD_9q2=Vrs1oB9Bs zLHiS(H&5EXNjy>7jsGXdyUV)L*2O(6 z>cOO)@c>MtXrCm2mq33WlIUzeXli11VpTiqhzrtrmUA7AOHk3@! z8hK^)3Vm=IOXSJ%A@bM+%N|NckC^umxr;a@*R7HqvwlSdW7m%gE9uhlmUch5)Lw?)Os{oo z&*B$(-i-3-Vn+X~YtPGOvTqO$%RasVlG_CAwJAxX)%ms4+qF|qxL|EF(;IvrRLx{n zA6-2iwaSsPjre@shoyP33n}R7WY49bnQF0(eVn@*cKV)sxQD_b(x&jeI#b|NV9eap zkcxU3J(!Iq{$&A7nYJ-kcfqio89EAL+pT1=RQVNuaU7ve=yNOq#^9DM>j}30>3CFq ziN(xG?#wA5VXKskx1^^0Pk9NwkC5H!>1#g-nD(e3Ck;tU_&1o1Sia2ZTW9QXt};HU zdeu*Uj_aDf)h@nO%Ke2D9NKLaI-WGUtrmSsBS*qgRn;0LQD6BT9k^;;#n}fS-61Ib9_sSIL!}>Yy)D?JjR^q#rM7L? z=;qvYzC=0%M7CmDyrMo^NgvDefWLtQZ*$a)gY_cBoWq8xj`6ZC zKCvfIe@gov8babGo;O34)|P>o-i|209ux(y$lD`!JAPY%{?g!ZP+(h=d*BQ}E;3N; z%I74AFY_&&gsR7u{>`P22y`A(o+)3Sq-5Q}QL#Aw=>mL>r(D3Wf1gk?%kggr*2}ji zt2lCB2)gjJN8l`p7?$gfvwDiUUk!%gR-2fsUbhiLAPakqkf08&bj>E`OBS$0<8e;f z2e_N*E3oGFE$iLpy~!{{&Qr5&C09ob8cZlG zsgbS(dJnwaGZk=c)?SO8krnC;p>F?*zwc7Avphjlom}@BnHegQ->QCvVrOa7AbMqCk+s{RTDZme9iUy|dy47NJqN z&qq~GU5_WvzO6zI!`TvED3w=eyR>>sO_ z^_pire1i?Epwx<=#l!qAH`98!&&5~F5kCw3T*CW_s6R?&^Ga9|8_)dmkG2)KB+n5` z&*DRQab2Zs`?CKI$Fvd4ZU88D6kQM5irJ-KCj3=FWrly@!H~m0k%;)_wUE8+a-u|h z7!+V}_t27~ycxgkAmmUSKNHSABEH$zCDJq9+2&~73n3q0ZgNT@gi3>_`^%i@! zg|i)TAM!7ikcqk=hLU%~@ZN3d8dcAAoo`x}z97}!4)XQFJ2{lIUof}H3(&!D?7PP+ zisS9)PNaSboi6n7wgXc#&IA4zHdwRYxiv-CpZ9{BdYED&eoY*Ii8;3StRQT2MTuvn z1Do$mYOw+%)giBXUyli{04w3vZiYThznckEg}T-z(rqSdM+f@|z@d#paev@v!Cz(9 zC^!ZqNcTQP?24aYIT-N3&qMy$57`B-PQLdnYq|4F4Ed*&JzDOLv!Jy^wFQ|rnrSYV zGp%}xrkvj3xKXYj_PIm`XW5)POC5Vp_s50rw9&=I2oXEz4RN}9^gdi%S$RNLQ_r`b zm#hxGt`TDHu&t*%twaWt2R*OX?sR*%P39~jShDn|r{)czS(WhF_nd?xBnJtAIi9AD zxy@lI2TIo`VEQyhZ!{dl{{%kajVh@o444x-CDKExN^7CC?>I14!cX_pzR5$bVy%!% zG98Fkc@sAC-ph2qhO+@3n_;5O@M!aK_S_U)n-AcQ3$BDkqe)UkpJ+JsP!?|#LkO!h zPffcOMhRyrQPxMPd3!4R{f%BX79+bGNo6qn&m_XkTqL|Zul|dl>w#MAh~`LI1OD7$ zR;C(U{3~*R{Nv~Ah4ev0oN1s)sj!UWZ?J=T#Y8UQhTQ&yy|xWQ;}GfH_$`&W3CVo2 zTGL$Y6Uy%@Kh{5^RmsA;G+rno^ZmY=Bx2Q$rue32{MP@DYKIx@u(uvr)SlkxznyrQ zCTXk3bOQAR0%1L9;9o((8-fTtP_k5g_LjV?KO6_?i(3!tj%+izO{f0VL2GXB_kL{F z9Dt;)mK35CAk^ldVU+=wo_mX)IF`qyA5jFuxe8U<2kc@EO+gLK&pvVHm5!ne@^N7Z znf)=zEY1cVKR$CDp=Wpw<(z%4v6NYIG`Idqu*FV%#26tsgD4v93qq!@yi^R{6ZNY< zR=z9<+oO&*Na`Wt=L(k(E<$b^SbtN|A7fEFc-rg8;~n*>K=QwF3*}}~^H32XFOkrp zQ`1lTJ1~IaqGOAp2mmULr#fojXXo|-Z~&|496nTY<`eJ#w9NjB$1U*>od;UwVg)EN zp`c3Cb>+_quW9;{fkdtky^I!FCe5?tQw(7uY`eII?}V-=n)y77P(J7xpyaXd`b$h- zBX3L7-Yxj;C%OcbkYf(|q*cV>j|`W1ls045Z~U#>0h&6c_nsC_jpm*L52B#4To;z~ z80_fvLHZl|)G05GU-tD1;o%&2KckO+ig(mH1vD7m!#l&T!j`|MV+U|3twr`Q=KRID zwcBLsOs<``4 z8{#VownloghTMxDa=Op-)npFwSF4|8(-%kWXvKj|5aIaMQj|=`BOKoP=dp5J!xQ{5 z)_bbfg*mI4Z32gRC&IT-AeRggtHnLY)_~yCMPIvwrVNwStBLjV0ukD}BZym9_Gp=W}P_6MfPyK?%>BOGo z+^?LB)pO(LeI;mA7IH~A<*YydUj^xX5(Crl)H!oc^0bwQl}#}KQv2c zA06D;*(&hFn`PYNbtls=sb6j&O`MgVMZ^m3XxXqVvp~cAh~9rlnYW}kgW1=9G~(Bg zW0!eP*;$_2dnuF8_uAl~`M!V`otCs%ZE`e=@RtRq`}DP_I^UN?W=@~ZJhOPH2D`Oi zHl4g_pU?`I^6;oo+sFukmb&0IXkgbKkJpWY3nir65XKjS4huuV#{d zYyG`0aHuU#7=N&Ag@52REn;{_eryA+RbdT;s?`!ee#LY6clGA*a#uMf%6&!eW zQ_m`C1RjZB+=ilA#SRGBv^06r9pVXZFc3l+ExQ$_*H$a42}J6MLC#(O6Gwl(daT57 z(`gqLO=9@Z`61|OM-F+X9dnnOB{5;c znNi}_o?%w1+58-rw7 zN=1|A4F;F$k5YD#e+=y~l^P^(>>GVL%}Cb)%-v69#o##Hp?LoEg*s?PH>WXmI>$mW z0S$D&{d;y#Qx-=*()RA_dKUbq1MJxc2$!_1+ZAN2$_PiC1rW~!o$W3G`rMknBo3BV z;Aru>J?EtmB;=yUKPhNW0evGq?j>X%32Q6KGp=Z!Ycua<*19If8d4O96|i7m^Z z>eP5B`MxR^tsGLl-4e-qN%DH?cIlqEn%yj-&fJat%Sp&)-L=2=M>B;Hjst`1@wt`> zgKl-8w3hm-@AtwR=iYxz3_^xQWZq4$-dggC9v_7LBrftNvXpR+2tnsf%Hgr*P>Ob( zDqzTQ;8>?Bv1?O-yVChb+iZ1TJ594<-9kh4A+IV5A@z;^^IOI5rm$yupmt`$Q`=8d zyyI{?Q9I$haV%aYh1H|w9brN{!U}j0m~e6pbDWND?;)o?f*cG~QVgG_4Syh;DH#EN zKpiHhM`qNA!xSAU<5OzhQ8lHj2n^Yf`Uw1RO8ohmYo2kTPPf7NibLZ=_Ua8*b}OWL zd__E;VuQ2c*JD;=nB}sN>xC@k`o~<72i*}o+W*}#lbOGfP=#{SX+9k5jPP$%aDW){ z{pwq~Cq}s7fe4Ofqi36X1H;vb1?YC$9`h{yoCS*XkjSM>S~7|-vBdRo0=D(fDwhRm z_#V-;45z_mGCA?R!wLxe6%)&G^wACb>?&iC1eSYc*655SewdcG5^yeg=I^K>`yMyr zAKF`{jK_AV(fs$t38X(`i9MlR_&q9A>MP;&uq_(_5XntdPg%n*$UCX;{R+v0*}EZW zc?|ZK?PB{2)#>19L>1e6yimkfm7y4xu9)thZ@1DU5{@Q!Om3prE7b(HXrk3OG}+iq zKoLGT6Y3Vt-w|s{&IZ%ocM~hCu$#;{n9TEhTz`Nx}DH26%xrRv8pD05qA6wx}#Rs*!ZL*0{KZ+Fbp8cNs-hy@0~UN=Q%11 z0k+nlBzyhB`WO$DG5`<|b9sMu^(&?}<6myyr)!%vnq|WIx^_BjA+qdCQ^02)=>+AI zg@<1+VE+gYPsRx3pn8$DUJKIvhmnSNRH^&(f8CaEF$<6&u1R@ZWEwI|HP{}42BA+CT+U&??kwNd-p3xTTfVscp>y%S=$8=TFx^W{MmqN*v4ve^Ec4Dh)7JXv z)V+AX_B#4T>nILpW=pYZ#hxgPnRUI3Y5#*-C1%=45<1{i;&@-3i&^M&o;;~QxhFGK zAj58tjF?4X=(-v#VQ{FBKl_DWE2aj{h47B;N!J1_9?>$|EzzWm5NH(xc|7{u|M;AT z1$))I39>A-yxbvj^6isvnA}pXOft*KUR(?z0zniaq zejM&T+nUex2+#Aa{Esuf31r_vctzWt)eH1&Yt!+*6h0_DO`Lr!R+-w-S5}` z+JE2T&dRO(cz5TJGq_{=^9Pr>jAjHeJKS+siJ_hf{+tKra0RP=_3w%gxpX0>%9y3T zSN#H*rTn;Ud1$b9Pe9o-1&uyUs4=JJFGc_ zMB6mXq;DU`^K@fwiYJsj^!^d$ui8jFdxF~o+D1&L7wyx|1XaPQM!yRL2a^56s&I`kD z41lYK#!&>!u{D@rgAMCkv)jB#rRHQvc@mL+FJbS zG;`QeD9x?hL(xiDb4WA!n17Si=USX~pC0`yKWgWmkjdK@K~Uy=;>AxJp8gSDTLK;W zuz4}g@85#;+kd9zJ)a6Lx1)lxC}+qFke5AX^#{^=D`0D2JSJ5!;}GI0k1bUXECY6$ zq|ny`*UFXeo-o9HFZH-%`=5u*9f0_p`4*ATPU7XJd4v0+@>Al%zutF+sJ&J{mA-4} zF_-6$H&dh-IWZ=#0N;D=1hVTU_|1$hcj9gN3!QSBTd2PnDK71VXKPQ)9F>euJH%z? z--Pl1NODD-@mlwnVtsN)LU(5C(gv>f z&@E)b)b=E3u@;XmasETYEg`#qt5=Tg$p`oaj9T`Fu#GFt$Z|bHB zA$narzNd;~fFr+0;1OiUH)P0txeXp97fm-5+GM-hB z_yXRuC_#HovFPzAG#p9l&%N`6AI%rl(lOeuuFnfyvh8jrm;ei^ww+!jb_$jIBP zquDS1p2Q`>dDm9jL19pq$lw}UrLxk1P^A4CYO#eNWpiNGAs$wya?d#T4EGuQqBCOy zBtmvJF!)=giRjTJdPLQ zv!R|SE$TbXUM#&2l9OBiUKDRUIG}-(WVJ^6{4sKut-!!7d3gD+#9N()faAm$!x!m* ze$C1ju8Un2#CtDd4a7?v@leP(gZi5Vvq7XhT#?hg)n|wJM^kUW>zDd`s{shhZz7fb zg;(Au4L`s94S9N``0&OY5{3UzQOAM6CBH5F70VL_%klTe;tpS}-KcYn3W?~1pJNl} zP9ESD-f{!s%Xk^o<2nnToG9!7u=5*nK?u+CF5Cm;ZFr15mrz28rrhPVANJaJ{o(#v zErn)vP1kNtzOaAC zunbR_aA~oodf%;$x6CI`zvXaE#&YRKzbJPx^~U~*Xp)tCQx!tuH=uUbYv%QBiOV*j zHoWU{beilWo+_63p;G-Flz;0ng;I`V;2xGU3Do4|Hw~X#3&&-qScI&SeflHn+luU# zVJ3vsRd^^cM(f%Wl0KJh1Da6UNrj%?Ww^~n85Riu$av8@l56GSUvZrGUlo*6@)Yx5 z+tapk?LV+JO=`6rU18Uw7g`Oes->{fkujR5Vp}TqxE+4=Nml=>a$B{=o={y6j7Ay1Ya_oAY@4e3$noP3V`+;YY!j==n zn7y5#HG=w&YH-B2kyhw)H^sxgSkwq!)nh^xXUK<(y{1FfhK4={_bme z^~UyY>9TwGbhtTQ_J6jj6DM| zg~&e0ip2ehV@s7Z2N-gKr~H+1IUNH*cKq`^J~@{=YOxfbTOx0NWb2q)TqpAEz$%QHXX_=c z`Ji^i6t&pniqM=VsO;b^(_g}R z`bPUgt{r%({OHjCtZ(sQTSBJoRksqxXW9a&ww=$_{Wb;)ZXf+pqL=i<0=?za;m^v( zc(B`?=x!PXpyQybX%LB#IU~uK)T0Z~n_huMveIe9|M7W!_Z4|Qs2@c94n)SrJZdEN zkb6mxR0;tl61>;q^fdL|K5OKqtLIsf2b?Dle$_!u3HnHCK!Y9pb(8qf+y9xhcL0$~ zd43+Mcv7&l6h@{d!B^xteE^-+Eo=8a4#2}!pl}lplBt zGFF4eQtZNnXX0kx$ONy>_>mTYI}hIg%Z$0@SVR8w!oc@Dz)xMj;Z~`BqsuUSltCD? zA}&2>X?+Nle|DgwmJm>4{eA4$a3LI(ZQhH@-9Y<)=aL=QfF8}GAGCU}85t7wr`_KpaS&P1TWQe*haKR{xudpSU%J zDo2Rs?&DYe-?ix4=|a~9wHIH^?k^Iat6&d>&Z~@-nrB$ex$;ovvbA?}?{2Em74C}+ zKc3O&laxD@UFsi{@l<7F3z6B8y2QI{fz&3?{126N{rwh2U z+Z0r-*EsA+ZK)162|3FODMOuP0b+CJ;AxQZw!u$u9R5VyfxTqJ{*Y`ouKo}Wd#`7K zQa@iTcsjz<)i%rVnAWH)Ps;5a=7piQ0byFN`iI zYZutU7R52(TeA+KpIrBqUx|d_j(1b+$Ba;0MaZAYD?Iq6pBZJ_@1+k zcSb3lkcxksp7J70@aROM!t(I3HL@E1_h`0r5@El!UpS2+zj^%AVNph5Wx21!5olv< zt{x4CH}2-%xH&oFiL$@zouGj_CS=m{)IJ+E>1de)e#qv7aWM(0 zqj%aG_-L)9KXf(pxK$Jd#aarOuiNY++xw{LqEuJ8hF{R`4vi~2X0Jbx)e4hfd1O(% zBS-vlRC0-ag|~-4E#mp{x*8E#pvijOG2!wB`f>oJ&{K<)T`4O_X=dp}>p^c^ukY$r z2l(!5J4f@4Qu>9Bw^C_4+gAL(%c1u)DkLFVz{#}{&Z2+VKIy4erdeyhn&_#U+3$#3 ze5n&(%WKJB7pVKLr4GMJ^BeLfxW=$)?Kuq6-4b+k(^k)wtgY*e;m>%FUMrr%G;83uvdfB z{U^XZtcbKLb4$PpY;S_UHEwnZ3Dmj;4A;1x&~>Oy+lmo%^eUl!(6S$Xu(g7&n#T`& z2p=T5y0%No-C;Vg%ZbyQUoL%91Z2CHoBq6;$FQ`LaKp%Ts}nx7hnC=~7G)5NGTAIi zu!7RIG#o!(D+2Lzun0Yup*M@jmV)-uy)}vH8oK#HD0e|hAga+v^4Z5*!oS06Q@}#2 ze~HF=6L?K#GDBs9ot&z}G7%kQMs{&4sj1eaw9z%FQm8aj+A0L*SA zf(|y3R{`tl*HvFKvq~(Z42k2!<{eeesHI&fbcQzh zuf4c)y8r!CY$DmrmWsc8H+;+Tj;@Mg2g2r|ml} zV>!!Awhe%uq^@|4bwK<_I^7DtIXvUp@=c0_`EkW>GyPO7!FDYiRL;gr9nPuv)P0GH zE4q@&^MWrS7`!owFLSSLBO3qF`p)~8rge!wC#RzaLiYJ(N@QNNK`^tr1ta}RFkf|2 z(ge)@vupI8EX+H5UvU7krkGlM=(enwf7W8-?T4HE^f^vV!IAUr!W@F)(XSJHj~$)U zds}hD!yxX-6`s`-{{*WGLMgEcLiBHKOJKHNC|2WrYBm%K^o9KQdO-mw(ZGHdM_-cr zVKR``cu!-^{UV*><$&>@mE+H@(U%YMw5Dv6vSI>6J<|11?LzvzdWz*Ora&X5HV;z| zTNfuB!4CCY8ieM5j&*5v5qfw3b3WJDiHB3Pnr7o2b0lAAtX`2}E)8=5OZPjrkX<|D zy2I^d-y{z8&#JNWe@1##lv!|P?P%fF86D8W)Ap*?z^7T6JYuRNgUAB@!`g8eFCiON zTkXJd|WZ**$>Aexdb>XXMB*hStz|tuOwG#Y?Xf z^Ln;^{Vm_xEB9i`)pmgKoe<0U-nhx)>tM@XMfn7G*=fa5PT-0M_dJM5mVfJV$_ew0 zoc8n`Y&zvT>Egh7wN0S;uBl82cI}^!NuG4y`rNhgC~$DU!$=4Y&v&OzlLZjt?^?%7 z)OT$?jSVpHx*aXE2f7$=waZHgdU&>n_2#y&v~rY61q)bboq3VYb?gC{{guhL6TY$E zohT%bQ3(QePreR657bOQt02{BM>tu5+gLvSohw8!s0nrn-1hz*9%&QC7C<3!88qEx zN*rL$x`+UX)0(JXS7FbTcuT(Gr3H-YKbkyWGLW3QJ5vM90T<`-elow27!;5@6T}V~ zCm8sDDt~wHq4S4&yHsldDz4f1eg8KwHU==d{uxYima#Maec4yqGupX9q<}%f03TAM zku@DX+a1~(j>X_9_43%bH9&kXdsVIiXXsMyo$uTlT|}_1k;Phy4=t5D?d#P6!wi4* ztJtq+W%Q(dpTl-g^HO$RA5AX2B)R-7lIIEuv{iH+wSJ9=K4C^3d6Bkqj8rykQ_9&{Y; zzBNnpIc&T;H)`ao?L#Xq`&X#&FI`V+mKtW4MvR29HTl{PS^LX0(!RW%ySiE+Wc2oJ zh5~0Opfatf>(df$g=9ZR%DYq}2ID?z$LMb}>?!@Nd^67U8SGj(1MFMAB*8o!&whS) zuKt1GRX)DwL^d(hv;1r`6f0?sCIC<{Ao+#Nhf*wo^!0epJx_~cqQ>HD)9Vwz?G;0s zoaY)oXgu-wZLw5?UZ%sSVNG9m9#mQ}YmMhB0m2sH`k$4Ezyn7#20OV$E)k-A{s%Cd z+t@|oVFcLm=M5a#C7#&<7+8JZE*Fy&Pn#XV3g@}+A>9YDrqZu;hb zh(+QFvOo6nZ=#x^m3(E{>zcf4EP%|za)z$tIH6R>Rk*Zpz?d+{Zr1}mmy$kqr5+Hj zw7Y-7z-yvCMbs!1+FbF`@dUiEcEOxUIDz4w`8wyQ?{o@|hyMi;T+9FZ`4gO~22+uE z4)bn@oi+ilCJP@dT82Mf(q1zL8Su8 z7q(&<-T0W>8ILV1{;KeKKE|MGp({!JQ>t`e7qgOxAh895)T6;O{&T__z+1O}xP=6c zUX~_fk=Fo5D!jP72%ZP5WtHvV+wW#f5+w$^!O)A@G8a?~)!2n9IHvj~iRWUesi~J? zubaYHoo$0l<#;U`%`cTjIHZ`FvMUV0;*GxM0cSJWViM^)zK@mLz(*VIz#6)W@2pxf)k(zG(g1A$6U*geWVXur!w$ z^GQrhqiQ0!WFn+a5dW8uN>yJFlTtARFsH{@iArcqA3;3`vP7)Co8(|Y7XZI5XLnIt zb1GNw={jR>%aR0!Qpn!WeJDlE9phSWNyfkNI@Sd~i?LFVgNt8rd$`^nJfr4D#OEE) zThCf0qi1c!^!|&yY>Lu8sn}tOsM7#z0d2scQFM$4950gENB?sb4yC~Tc8<8a?D=}z zD`yd58+iW44w1F8c~6$lx9`5^jB!foT7GP89Ud-J_TuEMIbTXx0MxSYh>5_9Eshxt z%$>tSD5dJD%G6?Or6crbV4gU6oyXY^UCjzOhy;gBy^TJs1RNUVapI25l zvmu(W+@+LqXQbxO)GejluaEm*(|vd~b$12&5e^K-kUmwK6LMu||FoVfx|H@V?^X*D zE1cF-OiE>s-YJZz2h(t5ay}#e@7Rl(fn>F#j3rq%9o7d|Txe3>4FqvZFN$c49P~o= zl7$EW(R(ac`Y|A=RqA-K^Ey)dwxP+E83vo>M4|*UJ_is%!t?C>xXzzt?fVPcG=DCno#qw>ewU=z;LuYS-4kJ;;+ zXv{NoVauFiJNfke5K5Z_k$ZVzcY|d`JwoEKYtKy*g2n5>I$zE_q4e+m(k{Y}Y&R^O z7tDaUFOg?t23jPxG0k&frvwLR*~!C#Zs<3NbM<8rn&L8q%Yv&@2DS$C@GsJ<-ql_` zQ9E553L$y-*=sVnC%}KT=QnWDdXS>&NtIjNJ(PJFtH8KO0_y-5ig}t}5ir7c?#9|g zywOwGRWY}@jn~5GhJ)#N9srLuqN{9@#9v50c}~Fey;iHiD#IppR`qsFvZzW$4Pd4o z_j@&nC|TqZO3?>0cs?_pK>!Qv%kq@N&R-hx&W_eygI#;qh}rHEOMZySkbXZKQz$fk zO~g6JvUS$?VE*FE4x|LvUBJ><%Vg`q4g)G)8gLocljDsQCS;A7@~^tHF4Q?Z0UD zfKf1IVXLaku>p?=ps+y9Da|AgWefppCw{V5d7wSAXl9*@&2m`b}vxMHz4pLeDwB5dzQ6DbwOBRe= z;ox_Mf1%$!^-nRSI6X!YRFxi)@!zE>ChLfQZV%joFG+lGJ7*!7Vt346OVnIZoP^2x z9I{=r`WC0eIc^7mrj*AkqQSVucV}svfd{PsdQ86G&b8 zOX~Aq3A-m%Ks7O$!oYWY3f9(T`4cVi-cv69QS3sNwdbz&qw)=7aYg?JRj|qKzo3*sh{q7 z_^aTitVeLdE`5-l#T(cx-IGGjAJ&AWH`f0I8lE)$;=P?;F($6!r|0Jh8-QPfE+pY$ zhj=*vCq>1~A$&8kL1`pNtrgGiV4()z46F?5V!5oMzs^8CM;p|c+a(WBC8s8-6D;0d z@J!?)U~s%IicG%3E8p6^b|7d$mH#&-Dq}^kV0?dZc5vB0pR_Biu;_~N_4(VcDr8>! zmy1~7o%85b4fM#LmzUbkh|kFkt_3_pwh{n~E%*dwL%QE06lc>s4TYtf1TO0w+!S^f zpX5-+%0K(maDAjM6fI25oy7jNI-(%5G{Z=rX+jEfzY`($qwJsLE!W3Jrs~WuHF$L; zP$XB1+9~_q%zI84N~P6m$&nH3$+Jg&YQ#Q?$dVh_e77-0BrcpSU=>q5V5S&nC zej9PXRl3ZhpMn7)2z2bDQG-jMa=wy;mf#y_ARGS&(DbBD$15XGxQEYQPLOR*Hru>8SrI>d$5e!v{tY=6VD$IkXVAbG| z{i$ew`&TAu94m&Nf&(D|z+qQ(v6ksL;jnp6JuA1Ri)EPy`86Wv+BO#W0CrITQYD1{ zT4s+?f-{{U`{;k*3eJhVypa1)>-n}yhUf@7P$`lJ_`fF(i@OD0e@t^_b=wQnqMb3SodPqo z97~#4!z0RXv#4Jz_xb#p3zz4XqD`U-E@gxP@G9PEDq|z8N`@+dXWldN^)lQv@E2O> zNBOS%i#0U{IzF*=lD*|Gy7XQ)26ZQikpBYz(f+chrykkiZ2Yds7}$oqb&~%3ul7%> zz(5F&LxfU|{(78eEozQfWRr5hHj*NIP;L0>>sO2qs?lqt9Y2aid1jFFj|(0K{0?*+ ze@Iz%GJJ3CJyDD*Nx^ge$E`O2eZRwmt2#5`Ci&S=eR2u%yNm4xVh8Fd9hD@BU!E{1 z^_WNVm2y#eT&^!?M(JxX8K~m1Qr5izk+G0zu&_AU6FkNZ#nxld|OD z$+Vjqjh-J9fbHcF{hM4GkvMQbcT(rYoNcNTCh7cW#_=RPMvaJAytPYGFvHcDr{Q%L zZ|Vj)8`!DfUK|X;m0g;dsC=}Br6n+nuk$Rf-(n^)O;xB9NndC9;u|3TLE;H;vom@p z*5}0Oq--T*s9!>Zk_EPS!mIv@_)Zc-&0E)NTgNpb%hEi*a;FJU&G7JbN%?yMBj)>B zIJF8TsOVeF`8YWbb!anco0}iVe?>{utNK*wo`W6P!kU-PLpVuzf)yOfSoL)+0U2`8 zM=DD6MA)h31!GmWD}W8(cl*f4$Op365`lqUQSLrEzVEI_U%hpKfFaQ9p!+p}yu=9x z9>94Y9SlqcqYgFvz;yDIkvd9_#GhA?yBXF+f7gPM>=5)@8oy%_vUNqY*09;l0sV8lzK};JA}jC>50X%D{7gKj}G^7(_$Bd&wm+QIPSl{}u-wJFOGQlKPr$K+A0L ztIcPd2Jx5SPiQWTZff#GVw%~aK#zW(2aK_gzCVw8*{F}W-*X%jVZqc08)95vx@hA} zyW;#fb&XyN&vmVu1-)l|;4SD3Bk=pi;Ca(^GPl400R+~N{_@!ge;&zre}HiEf#0ZU zj(C+L{M$VyvpPcfGDxxqvF+H`1(IM+X8!OoVvhJG2;r1d`6ik`6u(YiP01-k6RsvH zSpzxeC8GoJI(%ikEjJKO?(Lp==vv7Ir4oHL0Zh9~e^4E$NVTGp^B^Yu{izvNp+UYz zq%+AF;yh@b9!|3J+qKdCjN<=X_R=@6?(uTah>_qS)W^OQY98=nn(}Cd^=W_%Pe2C? znNwN9eiR^TvybU5jr48&`$c(}_|lUP3=togZIGRrmU?0&6twBYdM2H&uaXfmRv=3a zMI8V8nk!IyR!QOy%N<)dmg~;QhN@7Yrat2*7 z#U3Qv28|=(R@cA}*aZ#m zKO)9ZQ9V(8if3myI)0^_a~x^ea~#@v2j9mUz8?TqH6&91-`2`-K#6fXjynGCBQP0= z@q&|Y_m+yyQ7^|m6~!>g{og(}1+67p$zMpDqsW@VQa0f8krXocv-$R&Qm^7Om0oj- z^LgEDyrvWeHV!I&+cNfVChPs;t`mGm)y^ix_OLEJt|w;4NVMJl7>gZBK5y=20zygt z0}&G!tlh{}!|DAjfBdy#;rYm-Lv5k!ybKxdZgyIuQqjEEBy+|hT^>AkDYP{G_L;gZ znNVX%>kLWSf@mxHAaXie`}r_%{PJ1OrD)$djPhq-*dYmZH~&U!W&CQ|LRi+t7fZ1- z9^@74-a&m)*=n&_&+IE8mm+?%js8xtPooRDGXB9N`>z^Q=If89-xAMvNz!%Q^E&nS zT!s-V18<%(Ob|g^1?PP2*?rR|JUa|RtkM+jKIG4@(O}hlEz$D*&hCn`L=^|+yc-{k zv{;fZM7<_g;~EKi?sIMW6ed}JPj0SV!W%m^H#|>TdY?GM!wYQpF+1vvHvgW-!FfVg z!*w{DdYDYrT78WG;l~mYhn4oGSM4`vw-Q}ZtQQqsgZ|)uwAkwi5`@7%DW`;&PU=kF zqi$!WhFVxlx5n=B`2}i?A&2{0M%B6HA6u}R5p5fBXuXr9D9zLz=7(#b?DsEDbW7SOhXtJAR^pU z;E+-AOXhhJHPPi#x>-X_Y?sg_G%e%Na+>V1dhA4ONm(`P#V5-NO#W6xU4)nege_!+ zfT@4+x3=T5jg*q<3A{)Ci098^dgZ+83cPYjT_EWAJ-grw*t)~*`HmtihT_kg(JB;> zmdj-({$2)uUaxhTVbK^*8Zl=?Ns^a_51wZ;0i{oElGfJA9yf6((H(MGH;e%>TGteg zMK9qNuki0_!P-a3F9&{n5C^&}I!@OlR>NgOr5w~9dg|k{zyRqfIuP{zL*5c<&bB(o zlMjrwwUV;7`Q=IYYW->k;(u|-bcWS3Dvl=|yhAHtNSfQGDB*sE)KRZKKI<)MSJdvX zjoLp}B#0BLK2#P>)4Ud^nuBb|hx+Y*A-rF&J9iwsN3InsILJinZogi7Y~uoY)|N)_ z25Xs-Ob&|8cH^g#$|m2L5ilkh^EVdscgzCHMek#3xEdx0bQ(b&)*(JUT|6ew)IH!m z!O~cd4rcK%@U=TfQcH~=48E6_UK40=*YuD5-FX8+zSZ8e&T{Smayk9w?B%-vV?e9+UlR-09PX=v^W%Wd^087}y@kHFpg0j(=mTeO-s z;9t;lt%dQA*p-MU@K@$+ETNc=9WD}UWHPgAY87@-kE2ZtPgX^XlK^qCKmU9L^LqJ5 z_%C|ROA=fihI0Blro3nsdz;pU*tI8ru=2GsdESr=_~+;IdYosUbO6NW00r?U15vWm z${6TR3*~>frOF^t-m82Kz!6x+w@(NnBQ`)>3)wtmzaiS;Cilx=Ns_jnFI+A1-e8L9 zkhrSum`&aVQDG>nib(k(_NG}4x2Kg4ravy4UcCHtb{!f6Od~3vBVXUDUUoFde*6Z; zy!$aufVP4H|4YB4`Mth)9uC)n?qC#adp?Bc!3Ue5AH{WF3iIB9kN;vZmvg-ry>wLF ze|qa@id&KGf1L=uS1p52y~;bNN>HXS|LyoHRA%-e&oWa>R{AA?`e7sB;=2UDCjWQ5 zXK5S{dC~{4-BPgWhF7@{e!BJIa>uWCR1JX)fFLHCqE-nnzrr+T^B=0+QhMK^n%JM}v67Kq~kMHn_t#Ij-!* z(XOcYFAuBSyj=I6?_F0GOr?lrSnj;}Ue=VCcSa)aNoXAl_a)5k%KO62y5oaal3Dk_ z%3iRbyKi26lofFxn)NGHnKpaa9Fj=rc&3NwTSvoYVDR&yps1#P3xqC#TD5uftKZ49r5E@5g331K zFO%XSs11fGvS-CJGqlu82@p%dgNKQx=l|}~8plfFNiH&df2&&^2{-}!lui$$tmW3z zttZjX3SwWo#B=$t@dtXYjrCm(S}$0#v(x@RlFm9Rina~oOCyamh;*Z%fPlm|sq zuWXwF4KQz-d`j<~*5|2Rl6AlF6F7hM-EzvHSN>Ev+1%-`wZ?Xg>+|3_zjx1~O6`M> zVVrl3F4jcqzO>mW)D|EEAaL8fgxcGGIZs+a@T&PbLiXy~`8Vomf*@*<-t|j*6M0f9 zE6yUDGvU*&9JXd#JY-0(BM`lVe*$sXj(}Jab+`uZ9nFMtC7lOMr#SUM;bsv$zcE^i zg&N5W<4-iwve2X+hO;q1z}X!ukx_=#jyExmPQp)MEm_!3`lImYtGKyp`pr%lC|2yF z{qPAjx8rL{Xb9G0lC6A7OlTSL8&Yy24-}Z2>v5C;GoD)L4Uw{wyVn4v{0ejOYOG2RvXo&P@7OU_p`kYQgux$ox=dT4H=7<*s97OVTcvxj!D{7f};Yx1x0Z zdq@h131=3$XvJV|q;587on5cCKGs1YW)Mk-*Byq8!DzQi3BVGtsF}X_2{pTjxx!a} zu?<_I%={xLx62;#tTe00A@OfQa1PyMyEORNbiSLTv+QAC)82(c{K)#N+bGa$=%)p& z$h<4uL{0{`hIotI#IlW4OGX}LDn9K+8hTy!M+TZNeT3%MQRXCy4dC9kiT2#}g{blC zp!WeYokKlf!*AUF`i{k>0)i2#6PIAJ!A$sK56}X}AMWDyr;|5{;>(1JevOq@E4&RX zmYmF4Ki6q8#R?wDYzr8H{}5TN!?nrHH!w+|c@bV*N7nsr-yRLEyuU%b`mdjsftYn2 z^#12=3i?R0Wg>XKDR=iQQ>8te=<3@?n1E}f(BhLO+wl9l%}W5{LVc)l4^BuheyD}D z)pPK)O1oAP=tjw&|F3wM^IUR`+N4)NE5i&%Nps18fG--^g32!`dW6K-xeY z3xD{K^$a<(LaN%+&~m%CQ4&JHTvzG$?HiIMedgCPcN+a8Wk$T-U%<1F!p8Kqpdp%5 zX5447MECE*%%W-_@jGGCtxI*#PMcz`nHp1-x3+`M;>Kj;ABp8ABGSuxKl*2GolEx) z6!Ts&+qfTag&=H_)}&zgr*hM~fbK|(4HN)Q@RUNF@@RbxXBD6qSoko!-t6C<1C^dzFL|{^!|z*#jl8C~69K^I8eNz{ z_PeiP7v=Is9cV1xeXgBV@M19&5lcj3m#b_oF8I2F!vFxY?epVkr;dAGs2C|>rH=G+bO*$3$_QRw1~FGE9b}YTgCr%?oE|TYAgPjp&5v$Y%jAg% zTSgwX28+K4TG1{Z2sE?`B@U)>^Q>00mx-$=YlucD{8D@JBDVO;=plpnqqbT1;&|mk zGwq&HKI9AMoQ|gkDb(=ETEqo0&CyOk9@+{+Vgo+ve^Fx)j$BGfzCC5VGFOhO)T+hh z=aEM8z6M3TrTE1E;31bK`3-nw*OxEhW-(HBda+H*-G_2BdpOoa3G_lex!mi>IQ1I5 zMM}?h{*d=c`#{voUI7R4Y{tXqzm}jlkg}txENaW--V^BJTgylm; zc&}#_R77uXdBKe-+t6=!1d+v|%VJ@*VQ!oQd_QYkj>HP_SH5+xA5!JYx>lm zXSmZS9}uMjCsikbzYLBDduc8 zvmsKoXVP1~Nu`f%po5u5oo;$BdXKV9!s|y7qAqBa=~4GyZT3E*cK=aQyxADKH!o$< ztBstGN?&7sk-R;9&?W^ysJjfB!U5c;g<*WkFBZjjOT7SEehV5N!c#Cpap^vEwF`yQ zMF9gpQC?R#6V=-WUd(En($jHY4Dy%*A7{)I-O0x8hg0Myk}MHBQv6O2+zI*_NPp4G z)xY!^Vc}xaiIlGTPK=Drm3=#yEG#A4zhnRfDbN@=jJ$jEe)ehjJaba^cdZxOc|4r_ zTZBp0dy?YbQ`oec{y_qUCR(B(1X**9E^#=(A7rM^mTL%o*pC0bBU}dgEW|q@kpDCa zmP;HmMIin{8szo6X!@6o;F}VCjbaey?j8jPv6|`xuxatH&~?b^!24f~ZD<#GB<|OV ztW6Llzl2Zts@ASf&Jfc7a$#x*LERx-$l(bbu@t{rdYudqHF`3gEGSNC39OG>7oQtO z#bz&x-}M^mK@r;ibVlMWd1emR@c70FF6bp-2y8^F&tJO1*U!?Jt?B#Uw zj+LbbxSPu_+l%(pc&qQ#F<;aky1BF^ zQ33NwZl`#eW%ro8OWE@mUr^s_H_FWaG*smeCZ#VIrW%+7DY}P!e1f|4?iAX_ytn z(H2KP20fF}!$V(2{o{edPeLuuHks7~v?aIqOh{Yu?E7~?W=e#M2iUEh2|L|}Pmo1~ z-dGv5`u5haa_hQ^Mrl1btbohpjY?~qYoY@6pvH1Ut&DJ zpgyvn*5$OE<-KT$JzIMzUK;QT6;69Nz1q&{BZbtRC=*%teWQcayEifPH}W9i2dQqE z`RfkXGNx?rakUhntQ$u=b!Q@Dk$?I}IWGE~klNFL*6P9cR1xtLs}gM(^C2} znNu5WjKGu`!5FK}Y4@%6F7(+ohU3Ba=ojq>aQS1er{_4xeS>LLq?_v1tgymKa;7VZ%1!dSpkq&>h}~RpCG51!?i>-B zd&;ZAQ8s8kk-<)Iy#Q=9oiICG%pZ1Bc&kb1B0p8}?EY&F4|LMl;MgSFG+^_RCydH& zCf_lDKG)w$44V4zuIl?lSO<|LF)0Ozlt$c}*~u)10Qeutw<+79FFzc&OTi`mxD5!v zf*D$yon~BT`#OnYG&NEY!b|d$L#m%?DiHV5^bJ3mqibvLOV>W1_E-9$PH>PEJOhehF!d8zc2GO&dEOGJ zz;rX~%tZ(&u-mny_PTN_gIxn$_S~g=Z;8LDt096*qM5vdt8#r$pGS2ehY1VvVUGf; zkZgzZ*W8eb#8zIDpb3cmKxB_#_l8yKi|=}>q(P$>(Z>$TFOps!%Mvt zjG9^=h_c`(|6Nj1{HDAK(n0CdT@?*a5BE6w;QJ0iUl%8W0e9-(3x?xpEn&zp@?rXL zAo>@~Qk$RYT2l187j( z8*Gtm2OYI@6u}jj0A--OO9`_fl=`%)HE;+Gg+UE9;i8ukH6v*Qjjn-4%dU&=iZlEF z)-=cqR&H|cL~P9ay$UnK_D$W>jZD5Hau-F)JoNPi0fE$9kd^{aq9pHs7VZ3OGyKcB z_W6?5UBh9&)lpkp;?5SfVMy_>=x0MPGlVIzp;DUm4EhJF{ZBjN-~QBB6p+3n?&4iv z-?O&_>8O(V^~j$OKAk%itk{kG`T8@%<0)q^ob>y4oa=yBaQ1wMR`gM%vL%o+E( zR|?X0>8E535o>gJO_}w=j;l4Olf?v zZ6S2X4v= zMfd0a?F0kVeJGZuur&8GtMH3e<`4*X7F}D^GuGUZ6QY*~o!BY4l0U|lY9LAmeHPIl zQ|8X!t})c+7|fY_Nq3az{QvruC5Cwq|dt>(x-#2qEm&qlDehJLO!b+ zySBigH|M-gDl1F*Sj~ao)y0c@YVx@JjHt7c$DfK^7@!%u3?1D^Fe%ORD7Md>YaXv= z2*KzG(!iO@8H{Tuerp9;p=)xG=iY3vByYQ@ZWX0yYduSGgBg2v#Bi7QT0Oye=Ob0} z--T>P*pvcwn3|gxP)8CU?lLeRwTd@0AFn;Dp2*gFI2|J@9>ke4u7^Hr0a{{zto3m1 zzz_^m@y$FJv8(5aQBjP^>*vmX)rJM-$E8}{@7v^OMetv(zgo+0V`dp+^kW)-sg&F$ zk=4grtsex9IIdhZ{w)-H7{5o%=O}e{OMLFtKc(1H(&u@8{C2BVxqj+3UfZPFX%Q({ zBJ!t`uE=dFNC|gb8XPN>@mMqE8iUg^T*OB8ja9@x{1R>&D z|MYD0EXBsev21$s2#1vb!>pvYc8H=91Vo~x}C01PxEM@ z;<1I>dFz?41qHgm$MBm$w%{XCvM4hevQY;jKMsgBIq0jxW6JgxQ>D*Dn#l}PmEG&4 zeJ^lSJjC&cf@rV*#|pH){c)GihdSZK>?%)Ol$~x2Vlsh9nUAlDNc>qFQ_MKAPc4L{ zG=}d_vw!S9%lCncofKXZM$#&uDNOjI;jt~C4!kA-jo-NFo3%afY(IkI`UMj)-ZJip zi+{6t#*k=wbTPWXEy=*I3A?2AqWlc#LCBqyboq2&@=qwz@pW$8bn+X=hjraE-QOZU z&=Rp@s(HS_EhZQW)P0tv_v z#4No+54t}Ud?RqtXDC*;(=L9m>dJSm=YUA$G0Lxj|DqS=A1=Q!G<-ufn_^g}OFaB` z#Ka!_;^Y1YpON!P@c8Szv)2xU&D}1PBJfc6-|pZ90GK}n#U?XQj~|rD^I&o#zTG(6 zoB9j*d3x#0Te7|nR}Bh_@bQjcXHqh9xCUzV4KaBe0fSViqLQMoQk}^CgR#*6}rGFdo)h zof}2FksDppzdZ@;ae>gQJ${2vKf!sfuZd}{L#4**J7suUS z7H-CmyOs8z6ZDjgS>*1KB2uOXG1~IArj}Hh4D5^LAR#%CNRUz8NT-j>KcYKjNS7$M z>`OiKKt!zi{Q1L{3w+));S&__`#=FgOd}EjdbX~nF!e>+)Lc(Ej zS(ZJ^#*Z2ks=leJLQC-O&u+zveEThtw({}jk_;aA1@8_A(RX{kg<>DpYQcLQ^NS8_ z2*9xZaQU{^QqY?lh9=rfoe{T`ntIp(+Ak%_VWwk^{fiNTsE5!NIVmu4e6#j_UgsDd zh0%MJQM~u_^RIFhcrw}@_&_h+^0eTw0p1hf1-Alvm{l}(BTu9Vb)saIep#YEuvwP) z=Za4jvE3&2V+B@z#m*ZsbJ_4?g~(yuR78a=DfymRSfr1jryc;joGbPi6Df;JNR7TU zxg^D^1O+kAEV;Ja1%p|>H2d;#6THqB+#gsE_D?|xttq}TU-APlReOj(k+M%Pbh}oQ zHP?k>sduBwG-y3}pOa>HcE6(4{n*rfp9E%E5#oC`S;KVoGWdX|LuCPDj;JM>r<6I< zJ2lh}hHbCdO2FIG$$9KKuG?QI9Kus0O)f72M}7NU`r%$3hSBBwxIti+U$3EAWct~} zA6#q72*RlTE^4hZ6LyMI6MWo?$M7r@Tab$020?N$pDj8ayIgK%0YtZmN&lG;8#^rR z`{!x_1?8kPvM#=iF2qu1$!xxiL#0x*Eyc;;ExQjg{|vSOni%3}XYfbjJTyY-P;F7VhPq ziDwW5ib{YtxuK2q@mU)qtpY$cZ=s;9ASv1KT+G#KmaqM?cvNertEf0&^R<6S?6ea^ zFv5aE$CgwB8lr<{j9P-wD83i6d)Li=(9{Rmf!T) zdqO%&0^1*Qe=x(t4n&;%`4e{eq$N~p<{Q7jVu*w24L#1uh*0Q-SV{3?tbZ{xnwMapt3|exEa;L}fT$%;v}l2- z?47n6BF!E-pr`G(o?c>~H+LPpp%p)3d~7n{=2%PeC!MEEWWMy@Fo$ca5}e6RW#Q53 z@;lTRa4GSE5G}LwlL>dJo^Ud)9EmgdH{xM1*mxxzAAjXXQJ|znrk}VVE|x`6H(L?M zy}Ada!13dE!&D8n&FR>+!<=LFn+v7_h^~ea6E%Feh@Iw{Zc-Wm--W6g@3S9> zKs;eqNn>@s&H+H}|0=)n*(X{xzMU5yaEluvsbAhGyoWG<=sFWo;QCg0LKn`*9^nQX zo>{Xu{xv_Rjau=zfUaOjY)$*)2!J8MVm*ZpMKE|W{PrEC^ykMev_5W|(~5jmj0jCn z0?If(T7)aJ{?*~9OmM9IGj|-*52BnunUyCU_p_v7{iqgx?W8v9Kz1VDn;|((Qf4Z% zA8e6N9Qf+&s0V+(?$kJ>CaLMz^1<4Dw&BEX<(xK-kv)53FrGkJ|MOdU_64oPyN1`4ki1blL>`Qg!ZOQpcyUDCNOT>F0xW3{(GV(g+G(=x=fPHP?@Z&;m$M0M>| z*qs!7sp{-FF1rN92R>`OZr8TShfAEIP~m_n9|6yaV{o#AaPAHe&!He%AsWj}FKY1d zm9J;a?G#R^ezd}Hycdpa#Y@`TQ7l}>XqW5Lu`Inc+5U2mAR_e_Lvp`kV(mzVoo?Xp zwf#js9H^VP=hr7%I0RfBL7!+EK9Mye(qCOMX$B3-ax5*8k#M)$*};d~OJ36kjL|+= zOVLF6r&&E$qeVrYa`n7f0JQgF*nPZstoJDIPWL^cK$%0M0;Up`2p?rgo~Bk}S9)?M z=zy#7!smY~gnQJ#9Ut0Q{qm$=G{`8=`Jkohzl#itbPOQDeRbRwY4)PCl{xz-J^7L0+hvl}xi83ejR>8~ zky|Bn!f>5iL8O#r(%*YF-(wrEm)?uJ1Rogos(Xr@S4TNM*?xUdaCN`yd6G)+Z2QkE zGgfKFoNRUwfvGIL?36_2scok)UPnT^IWI<1i|tqh^>5_&(O2Mm!6a*9rtY2Pkd8+S ztRnDDV8T+5WzW{IbR#n)i=ZC`KE?Wvpf0U&bJ=dj-4|;gpB8Derj6*c9SO|%jbGp- zHR_0bqC?X%6XtO@%FY9?-*v$cwI&*9QtHw7^d^K@4WPvJ;GtGjO`+pec{QuE^h zJ-p|l+3>j+GWThHCOp)|%021Vvj+nYxvNT1hpBY0)6eS;$<~RSBG~ViJ2A|^Z~oR- z`X05wrTi+1QMtap%2T1FCPmE{VfkQl4EyO*%~$u+FmO?TcDBp>S!qru4=@J6c;-he z`s<+9^r+g*pI1-|wf72RL~J5n=1Z|TxpT`XqN`nh`lQ*I5u)rS0PJdM1xnsHG4F%^dPcv_m?@V1Jv275-(QxRbU125wzrJItZh#0*9FdUz z&kU_DIrgdt)KAsC1<>1q=)V_r5s!g84eQ+h%(}KmO}8679INuctIbv+89wK4%JR+& zcz&3r(ogWd^iTR}O5sH`*$CtBJQ6XuJo%|+Um0#JHvNft`0Ij0c_Hbn$qFGC_Gqo6 z_S`ML*~6Sr@h1X~!rePiGXK3J`YcmWds=pO0YXAi_6qp$?2Ot6k)=SjCEI80;;|hV zV^bjYbr#NRQzP+@$YRq-$-$lYtrCj_RRd9=S_rH2NfK+EnQJ5TP3*VJiII)#gR1|)SrG{je})n5eh;3rm5a=} zsduOIqn2LT(H%gvOHLzAzhj%%-n%>x^wYzS;z|lHpcO;3O&1-3LRG`VgCXHuN_sL@ z&39B9OP6>F^9&#QzwlpZZ;7n^_HQjViK8;V=EI zbjdyPHuxd5^E04r_FEo`hJ+I@*c<20~(;Mzj5H9r8RByo8w z+Y}hKZ~Y=$kfGFvg7fu1(n?*UQj=K&>mP+xPT>tredEyGJ-1S z0k|l-xcsBbZ9j%0-HjZ$>T#aIwG-fx*3ezn;O<^REG^a8W+B^T%N98P0fcGBw`RuW zS#;+xP&mtx8z)j7oq%@cgmL&U!B_6Bnl-3+nd9=)AO+rcC@Cqw^I0IcV23K%1|)S8 zs9QreSOmYly%1L`omz(-`xvR}ZTI(Q?Mocn<~gH6XN~A>NUcA~E8n;>6PoHOSXp_O z9N20#h8#(hu&3U9NUEzTVJnBfyv6u~prvSE^lzaKJ5R(*tZ{=&zxlp@MY%*(f=gt5)oGFg%H?0G5&!co>?xNRe3M&r6bve{ z`WqV2i8?Shm>lC-rgpfgIWlD=wtxNryH2`rDsC-urlj<3@lU0qOjrw?;PFf60xk)C zlgm22BL^dl6Ix7%keBw6pyO*KGjRdAw~UFf8L4x3nv5HP>)cQ7eCtoeG6BMrJ@zCy zKB4x!HAE(0W613)FXipyb;BY(5!jVkpKjwJ0KWwPWO+KodcVe)PbMP+Mv992-AUd;fyGEB43yW6Mt`I%FMc--2d2WHSk zRbXX$SD=Pt!rz?+1R)xghCBqZbv!wv#(sYlo8d_Gd#P1(auG`1o1~1Xd3s;BSk-OeNpag@>y_$) zyPqh(vJ3OuVUEF+F9_Y9H57zq;gIK6PdUA4%-aBJfhdU=-MmcQm)-;q`SU$-Z$BMa z6(CtK@?1B7jjd@b~^Ex=NOhe&m6yl?axa@bKNtm_Wm~Cc+ z4|ma+8TlloXH#M~?ZNQ8GBI0IHy}9&fRm&{Zcs#4R`v9{RkI8TJbr)aD4Y8>?&<2z`&v)^MnNWt8-1vd5~fe5zuOj zC#(_XuvVY(PbkJ;@_xH3I}8X6QlKc)gO;pK9?2-}PUtT}A!1IZhY~>#5c1+Zo54k$MLQuMSm%S-5W*E3o{7B9%ZGge$Phh4q-uM6 z`7bc3V4!|Rz8R-i8nB>-NU_Qta)Nb}&=G;CqD0TNw|<{*J{>dC`Bz2Fkt6u-)}_F3 zg|1WTxBh@9=&VF*I?<2!GT zt`|g9S+F{eD_Jkt83`YifjfpkYR1|U>;)qpQ@Y1Lb!zu#-cGj$3>@&@MDUXqRP^tK zpQYPG0Q%qV&Z#Qx0mvF-r%QB^0sT7Rtwt{>ioZX>lft~1@(yj-%=)U;NK=w&D(rQs zs)V=c1%>LH43t`yZ%9yG;|)IWlCeer%k)q{We=jHF&%DKA3g;p@a8`?@_0~kFRc=e z{WS3?>ejy}hKm}+6dhGB<=iJ%3!wA%GPBqBjX z&s%!eWScbd27{*cP;y?(EgbB>Iku`=oF;OLvEr7`2;toI)l}dDbUxE)qw*o^A((gb zcIaVzD{!a;Kh?-`grtj~5Z&8W{?JXNg`W9!g9T=F$AgKWBNgv;0Vk5*kZm-#S6^=f zzdX)$>o$j$w=|)cI%~;f-+eIvpu<;sPklsH?cM>L&vdPjz{yhc(iI8_f4+N1y9$k~ zl@hj5#EKRc2OOLptzH-Gc9DKmt2aaKxOCd|Zokl=syk<1d_fQf$&-+)^m6K+ zdM(eI&{2cLQN_Ao#>fw)#&8=7tSs0aOTtp*DHorFx4q6N;Ii!bU>i>`@Zk4|{a!?# zAPqAP13;M09@L^jIOoM=RR?nU$B~`4%j=u*tVW;LTh~mAlk6xb++i|qJvK9U^P&yb z@X^Xkkwy=dJznxLck}AN^`2ZU8Ipt?WI?@gHOKTd^;?IXKc+DQ7O#H z&((dMu-rcQOFN176OYIgvGa?_*g_(Uu^Y9aA4xdcD1n-w7O+m-z-xrNJPx<$AZb!P;$qnfV{Y8n(CSYpR3H^nLed_N05lI? z7QTO3IibA}K(lIRAV`kb+Xdwj~Crj2<^*wI@Hdo@g|6xFXfy8?rL zPOr#!u6jw0H5;HVzk?6A-~(IBw$4i0{0fM4CtKIPAHQ5yXgUa(723aPi$J*AY4MT) zyj)!WReCX@zL`@LUvv?ZAY=Wb3Lo*K6Wnlm-mWiag1h87`720V!S{#Gm;cpVkuvOk z7cd5bxHLqaDw{*7b*y(}&K++{dtqv>mU*-Xy!J-ORIpGy2}#*Xl4YLkN5z-~K?(TN z;%iUdwbO}m@#fM`ukv=%m|jYDH8mDfQ`0AjZG0q@9aH`;IHv)exbt}B`ihV^f&?#i zjiyAN-B+)C5v^oH-rTG*z=oc9dHOWqz?Bxn(TqGr&fZ#UZR-Ih^c`#1s8K616@||+ zd!%DUgCEF|--N(q*D(|yC~2Cf$V12rKMVU!#&Lrg`2rF8&T4fLe{`dD zdyQ&NEOSJ7ohEiPUJQz(b2x*4uUI&xeq%=aO%mUUtMPQgV)^JN*)xQ7= z<*ON51e{0C*65T}hqhL)jjK2M5+rTjRVb&N_jHNtTA@^XRa68^z$VQx%;06^1m#_ydq7BQhx!075s4z7O6k;Y zz4%5OaU$29j8mYoc>n6tNE&97(tK#_WcYmbsOR#w{&RBYuS_l~iyS>xWdomA9TB<* zgEa`IKwsE=MYn+)0b%L7BtC+u{=T!jI|u5Zf7*P1>+MSJV5jWI=7qRlEMm61&^C-V zEFMvrCT}5b3rA=#5QcD>VxS@HKCX|>K!5y~q>JVkehWSF_sRymLZ>y*6YeY30|%>g zB>-L<Q;FEj-`B?&Df{7NAzN$;MW=pvc={NS5NQFHIp zxC2s|3~X($+xvl5Hp`r7#r!r;lLFR0>6v7sSZLM3L_<#`osdf`ydJ&$!+sE;JOWQdLv}e(^ zt04nvKTty9QzzP9q3{(&ecvo@MFK zN1c#vCO)9=L<i|Udd*y!Gy`_JuBL7(xJJ1~j(%zSai=yS$Sl;vK$@QJ{VM_Cm1-7U)oh!5!^5E5_ zmNkvN@SCNRgNj`)-g^lQP@cvZk>=8Tm3Y+(5jJPhCr_{7UZ5;(8(SsIV@;gg38Zj2 zVOR^^>_VDj|LhKn5u^4cht^Jd+gRmh$5lkG?`CG3IG!MNW;L14-=7$>j&Z3<>Q5+M zJ%7OoIxsrunEhFi4wOUAe(uIBbp2XBJD;fKVn><>&dYnPs2bd_tB~6&!t!Y3<;A9_ zMObFJ{z0>kV12A2_1-x+$AUnCC>eE^xHc|&Kqfx7`Cl{0!JqrXQyK$`LV{(XYX8i8d% zz2tKX6fo(L`yjD{M0!0Zn4faodUFH%y1u z-%NY=AQav~{(`i9GzmgvO8`cE;v@H4yB+@wgm@lMTofcyhW4(re;Y}`yTN@4%Tnjq z%j3H!^ns?N1oEa)emBdDvk_-x)kM|F?r-^3lO3eIuFBNVDA37RGaF=gLZ(~5(S^)g zt1B{Swg>0opvBDL*J`7KJ5zfO#%0_%0Bc#(L=~UOB0R8@LxTS~2%R}I0nq+s7n>K^8)h^)mcBc5G2YZchz7_*|hK5x2;=hYUDSVXXE% zz#e5#2aZPJnbg}M1W~QVgb$`;YF_eNaKu;l`vXymGrB93*cXjJlsw$JWL#1xIEVWw za;fLT%&ST`1{pn4^v(Fvn^|SmO3muUv%8`8I7E@tgb+xrC0Rz_9dSzm>{V@9VDA0#UQ9IOw()eI+Anl$0OTj+E0%i`3RFl?|cNH zuJ^|uK%lPf;P_ebSYgUuNq+Jn$zLuPJ&x~x{g9tO<@m#S5VkvapBQU%L)A8GMe{1c zHi~OQnQiFk>)P=xcAWY5T$+a~{=Mvyrh|!HG(x+Ly&tYE?{Jo*vMC}6rII!u<;6#J zMjVX_5rmBHdc^j>kfX*VDSWrRfE`YVv1UFmg{7>5xRfMIQjp+Di+OWV>@&tg`-tO zD)sSPALn}pvY@V3v zNkNRaj$S2LX#F?rj6%*tL^&9(xDm*2yUa5V&0iI}SQLBCU zAeQ7rpGDc%o8cVb!PMGy=;h?j0tCZ2m4}Lv25)<`&I)K+ezqvnxP+KKKwb=-<#3j^ zoA&@;G3(BX^YCHXJGXgy^YWq$P*LsO(&jE92XV9+FVp!-S@NQL492PT4NVX*PO3-^ z49SbKEHE=wPnpDDtR($c0_Wz`SAO;mRh#&{{xTqfZy_Ig-Vj<#1r^i*S`3EZJoINi zwmF>?J=Igl3q6U$pWun1FB9U7!Gvtb`@x@8Iq0*uWexONhM0rKy0RYV--X=KCSA#n z^YNs#kL%y7X$Hy2c7qR3WsGwJzYQQg@ntG?@7Jz9%?6bp8@a!hu;0FzOd(qR^mVkxt7Ag5G-UdEF7&>0smP;Cm|` zjFTN~us`~JveG)mBZ;YKOKkfy4F_ssZvv;r zTE}Qp=|J8;E1{2exurfd__{S!;G+gRrKPR`UwfS;%S-j6A7J8QD9Eyh>(KixBXhgf zc4T$ZwMwmviijaWORB6|J1|X9EtXx`Qz!K{Xtza5K^a9P5eBDR^doX!ABUYOu%4q3 zV9k?5H*6-4Q&pK=mPqlKs;c*Dnw@*PH1Ph?cg-y$=usUd6>5aEU3Gz4tzB81Dk=h1 zkptA}URwuqyO!U{=K@E>pB*0vApw!%al22E6OZ&_sz7`_i|x*TFZUl&n=;crmzFDf z&}efIQGmkM(pvu}O@opb@it|pTr#5$v(O#3{%8XpDRe;zV@MC<4R)@;t@6)nww0e) zQt7FZnB8z&ry4%Se{9(&8e%C91{d)b}NEr4Cr0$ZZMb>=^&;Quj2g@b@Mr8N6!@Ahr+6(z2$aC z-ZCx!CGsLR9)l^pi`c}WGkXm|_-FT3L;&^JPs>13l!vd$*4_ups*%c>=vU;|q0ixL z6;^$J>ne6^(jp?S=c?|xGb|x8ZNy}F4mw}dHl}lzHcDFfe0te1^WCx51sjJbBPDmN zR0ZrMH!Sk#?Jmi&*Fqw?zEPj)8FS;9*Snk-^4gvqFwBa>)cBM64GP0oim$c$lpgXn^Y2gTr2lBWa7-U{-acF`~%Z- zcU&PUGM;YJW>L$=`kzxALH*!i4KCGGUv-b7)}#Js1j+&|#x3Rq98{40xWUcR{sI0y zFt1zaSuP;N_3quZm<_;`E)#Jd`jo)u}ud>~4jvXdI$7 z$D{^mka#WAb-S|Jyo9P9)GfSF_EN0=X;dp8M^z;9Yw{b!o{ut1q*I{cW9!gf`wFPd zqFf+az#}p3U1|iA9QwexkoSo!dSipJYvw@xS=HA8@F*ei$zn1X@Q=#@#%&2m{5hQc0b6#OWU^(V2029vV=D>{nCR|+y{lO9yebc->I>&bMsN#gbYJk6}G7i5AGzP-vpzs{03+TdyCb z+L&+3E)6%=9jFS=v#kC4GMLxYNrIdtoax?jXhtpr*I?BUcY7ilmjsDC zomY{3h2C68+J=v$yF7>}{(aUDC{UJ<00|=YCHck#{6cmowd&WrkFiN(i@!O78$?!`aXSn*N(UJ zZ`+}_D=zGub!7PYiul!=m_dCy_$7kd#2~HCM%<++pVUCcdFpELbrbY#k6}?<5Hv*S z#sfhJu66F5w}var@e?VR%Xzl-wXolNkK-+)?+HA9lT6O&_Ga6Tu$uP1DLm!_{g}#Y za{K)=#+8g|(onN^>eD3=YJ+ISyK_gOPQfjUoP3TJuBz*HTvE@|hE@7=dXzbR?K3nV@3 ze^SV>Edz_gaT zZ|BmjnUz>mSloS40!?%cI1O4{iGYmN= zTLE`_hayP{9f~fTF7ylKLKl$Ert7~BhliE1DL6Xh?;jlfB}YO*Cy)Ggfx@rY1;gh@ zXz4io!+@aD$d1yAeeg{GCTsP9Re2;iE1sONT7v9V9O~n1ytM93oM--~I#WD+ z=?#~{!XwK-=O86OC?E)cK-+oS)~XCXw3tQEl)}B0Llqy|oo=0q@d&5YOgK~TROHL2 zt|C-uduz1Bg1a}bw(pHFM9^&w=4^Z<$gfep`O4&)HV!W)oYOD-1p<_hs7Vn{kwv>Q?2jmAUTOA zURh%1-UiP>bZiCzvNcT$JHD*UUKVa{O@7Q`Tv? zuKJt`QF zQ(1DjLb;Xw(TSz~pAFD{1@=f;7Y2M*AozFLbfOd$eFl1n;&gq$f*}D%s@iNQv)tf>W?IMn{kni&95t3WNHoTzZSel6o=yF30drSf=iVZp@U`ob4dL&4yo(oOs0zG zkJwp14q$KD+j3I>%;lCnSBzE{C>{;uKHXU4_vNf@MiAq+@GsLVI1IqVA-{`|`9fxf z+~hPv89`H@k8VNmQd!|2@wFv6{^mm{3P?n$>0odJH}p^}psz7k=R8RTgt+8PeNd2n zj3BCWKtRo9Fz&J!JXaBjISTmEBQYJtr13_QV<}>gc5+RhZ-4{M*R+%;cs)` zKb!-*^-ZkLe^QAvr2i%^nH&-let3};DQyg%A_zgh@RSViE}pv(WuuCCRG;P%3jE^M z&g-)9@_-eL1!TPi^s6+(!zq8 zKfEr(c_nW04;kYBjxZov8-=Ak=hFXMjE*2f>=%R!Q+GF`|BN>9;@BM_+3jWoP7C3% z(DhDE?PCwI1VK+A!o0%>XW2%|u)&;2<3~OVGv-Yvk)EW!3EZOd4JTA1vVu`NfeO5^ z$9Y=ROmFaG7_t9vc#nGic!ubv_-yIX>La#@B8dI@%#ka>l?rjyqp1=2rw^mn{jw1KEpDKZI@m_UroeD#PcefwaMrg#@G8C$V1kUq#W)%B(V(jOUovL zlnDxLrW+I;N`PH_dgMD;4$Upwt!Z2SRpSY*Z@tct=u3Om>tXiuZ(W*@f)7 zg_9`UWuHBITiibUU53`aHP96M!!n)e3Lf;gb4tMD%wIgwW0h0;fqC-%2n&A~nQ_LY zPEE-)4;?Fv(dWMgAV-m5`nRG_;LQ&VJV_`v+O>YeQ5gd0+{u3`p-BqD6aG9iQf4YQ zLM)S$6uHh{z_#%Z(xzBs#j%C;cyt$eW%29q0>j1g$RWj)9G!yN$=YV$ZXp zvr?{!b*PHCj676)fVJ~N4qn#y8rCo%c86@j`7=gpjC*#(L3&>Ha{MVoId; zfp@fO)IAQqZK>0PXN)^&icWeBXC(r^$!Dr02YMoZOHJEke3e=G7EO&r@7W1e9zgWX z-H6NG_~OTT0U;i)n;T-E{GgY{`OmlO{seSx{E?owgK?9g%m)|mD*=SU*2AS*U+o$d zFq$BMvG%fv=U-h5nu=P*3+GqaI6smWS&uN_0iHPPtMEfTRHA6=9)k1)H}&_KOTK2W zs{aj`EvWM;sO%r1w$`z@P;o?rF76Y9&O9gOlyx}}1_at~e0 z!d?Y=gL1`hhz{H<6ADY1KrH1>#xn~$q0rP=1g~9vPLh7=uZkoIE7l`>q*eCh;J!@3 zF;)wAj4Xl4^oGD-MdRYT!?Ly;y$4=SUL9kjiO*IB>5vWd{osW$2- zCxwNK-me0`*F0WN${#fyve)nmDr`8AoWN}%W+X0ULD+Tg)*bP;RiLm<<6I}~$zxpL zn=|m&Kk+0viP98H?TO9J-rnBDXg~}$#FHZ<`Y-!L!AQvIn{~Y(st+KnRRKo0|J7f- z=!+Xl-tCvh^N1k*Lwj2lOoxG&>E?0lQ@4eRSJJA;i!WY$>D>h$D%C{`VjU@hI^$>& z=C)Y)+hyy@_kAv<{wHn563XwfxFs_?+MNUhb&+Bj9TMM2@a4r7C#QmaXqO42LL%c< zc!);4;hSF(oGWMBiC;($l7|)B~m~*HKKcu1Yp8W_k$A|Xyzd%(Og@l z*+Ox5tqlf4f_7_AxhEK~Ill={E*|+X4ueqqurh>%jLQ=BM)H|&*&eG2BzFkR6JJX9 zODK~3rtV%lT~!xP78eKKk7s=@^m4M@9;XsP^E6oKZ9+v zvB!3;Sho?b(ECZ3sd+z#igjPaDj+0Y_*YkuQ%(1~pbvZAz;A=P%hxG&*GTl!l@EPi zjv8J`yteGT$B>L=D?gZizIuysCfT%$oLYdt)bvnr>ZW(x@JHuvFklV-M#iHw9?q_h z)tQX~-;aj_ULGFBK$YZ*eu(TVh%Dsdv(rNWSJ<%xYdpc$tQw z#gj~<$1{yFUIn$7sFObW|*02-A&6V|&Clm{!a;HLQb z&fDo!vg?UP>^#jw>vl1a8fZ4hf5EgVay3P5M(le@L~!!EM(s)nB!C>Ni*c*(g7{FB85^`|JE>PFccR+03zH*Z zI1a_4)>TumL{~4){zDQOL*|rbF|99Zo21DrA8L7Q%8Y7{ zq^6(2WFO5~Bp}Sjvd`$l7crC1Cthybmw}OY`H07^x5*}-+t8c$qn~|Vu@nE1YIDdb zMo>TwdMEyb|KbwBs2;vTQ9Q2TYBvRjL?A1_ML@xGRX;#*UG3fh%tE45`^w*3$);>p zWK>4q5hfsgIs{VcgS|H$?LE5;h_t){)*4HwBIcb#0Z*J{PW>&x1KEcCQOjT zKD|vjFyESP&42#mr0ke~W@TR(!w)csg*`l@>w9x&x-kSjTw~qbEirU61ivl1!-JSM zZ#rlr+`H}?SbER+=d3g|PkLYOUtKgm3O`CI6j{7my6by}Qy{*37iI6Si!7yde9YSR zwix}lNFND*d3s|_6(aLf#{^Xjy;N3O((>i~pm6wwd5pXf8F|&a^X6&ux}(Lb2laA< zJ@8V^KE(2LTa2Z9Ce474BVVvrBhQvLSr+=!Hh5YRnKINy#6nq(s7cCNXkmOPNXB&0 z7=%h6*O1Jt-zhrnL_$a=QE*WNWCnSApm@TFs(b&yA^R-$aYs{_5$G{e=aL2)h3$>c z#DuABBzQhN{A`3nY@S;i<0`&PfYlsBBlu#~`%Jqs)#@Iy)y#C!g@i1{LuYj>1n?`$ zJ;oKojE#Zohhn?S1%dY7Ftox@A=M3D4O<#jAq>{{LBCTPvG6gG=*1a97`c`zUqd zH27g?++9TWLxG>^s`KY23(-xky9VeQ*6tTpBeO=yq`xW}S7()5?)^yGn_G7Rjy$F) zprT^Y2`wuGLrF{5NKFX^aiy@1j!MRq+==wmKa{;8ek0*ajmdk$jaUo|EccGD$~VSl zPQSTzPTnf+XZCwa42}@ya{~@RMrNVY@I5N$Ez`>GNU+I;za1X%b(nE)2>R}~N8z~+ z@-Kg1t>7{%ysxom=%~gj*b4RkE?3I1S5b4rS^ROw?bWD(tv!Gv_u>qc2IXrG3?cT$vePTOG6LtPYfOwjN@Ri61w+X+kn;7-rHFq*C-;=%oCoQ-aelFapyV)o4NOM_%%4MK+bdiW!wOf2@#gqhA})2d#2^1?^|5j^ zI}y|#@QBps>a4(ZJfA0j@=xsEof4Vqpxn4%`snrXkj^NS%Q$nz@Ee4{9>`He^v$nn z*p1&6(QP2frXCbn{yVKzBUv{I{7L0FmZS-X=~QbVep=-1Ij?^XCzcz40t?Bykk;X1Cb^T`lXeP5C~{?_dw;xK0GXy*|xF>28)Q<^o| z26~`kQdOKchQVY=PavSA_pS2Wa`XU!r_%fv%o%0lK!C+W5%+peK*D&uv`zZbnYpl& z6L=M%On?3w2gqk8gSjg z(A*eYx17|_8s(?`AJhhgp~#HckA}Hzp)|PfZ9%c7{SB_$5uNDr4*A(DDNP7P;O!*xU(_h_7kQs^k|F8Co z^Y`IFhUWf5nVKP*d~!KPb-aZrV*eT9aQv)cv1nj<0?sgg0hF3bu?Dj=L#B|OCR24# zT#C-Kx|5{!WvQ@_Se=ST6*7C^$tVIYoB|{!HQM02efQDn4fO*MiiQB8y&h9g&6sn} zr2JcL)AT@wGe3AI#g;tKTX2oc&B$tX|t_2#soouMfXw@_lNpm@;3C-pZdj_SC+|JzT*BcQh@X@> z*@n-9Mj%9DK&?yP$OJq_-gR{_kdV8 zQEFnZM8qK>Kn_|i^qemwjj=ohkT4Degm7nQZurMcmK^vY7?pp z9J{bap($1}L~eUEU;Oty<-K?vdy4tolJjAMY?B(%oHv}tm@&n38jXLyR~W>hE3@EDgmEx8b4j2 zasFn+D~lBMosM?AATrJ5F27YYV`AgRQ){dy=w*26yFcv;^%WByH&cr*p`u)my3dG{ z%cTn`*_krE`&rjE0^H!gMC;%l7?~|JmPg+hRltC|z=SX&H z03lS4l-NSXsk#BjpXt-{I`=9@O$6H&l^k9sYs8X*Coy?8G&d!-xbp`pzS&t5J+ zR~`5tF*qG^L2cBedy#UtlL2OoUxQx>o;f0}F(_;QU2gq8$O+V`JLWrm(N+DQ3zK4@ zKeg3P*7*6gxcNJk=@K?6yY-Q#I*vKcT^OB-$U__T|#<&0D*Vr@KTPDhKbf7)^72D+AQeX^fRUb*3D$WLzY-I{DcVUMtv+ zTyA$2)5_1ut_h$%Dz?717Xh{;2}86>W=XkX4y8{ZN+^$%#r8N35xNW<5QdvH{7&fW z6w8|y0O$*a?(MjUv6(R_@~US4g0?a*fi|tIQ0SEe?@yVkMddCK{nFRM4?wRNHlYEQ zs4on&lh3_-lF8S3=wIg_G1Lf#vt#=|Q*xx_Wfz7D%4ze$_93J5y3zZA{{Je!rkrL& z&kEj$x>n4V30IuV|mW^x(j;HB9Mw$Njo8n@iWqGv&dm0rm4b$Wrik&>iK zW!L+kU5!;Aq|r678Xl*Fq2&ho=?~4Y<=4dp0o;yqGor65P};&0rP*9`Co53KilfKt z(Nl7F-by>}qHdCW-z;N26ZHN#X@%TwXg>?FYshkczQgW2kDUa4J)2yDTC`EpB*`Cu zT-E{L+2%U-S*77LJRvgF{f6S$DsIprt8dNjOAS(U~FYyhXr*Q(AVm+ljD znYzpRr}eoFWAUp@8T(f#sIns``HRq(Db>6J7W=28qpp;Dw;K?30!HNgotp7`(+%8} zvq{qje-O)5=A!64i1bcG^dqiF4lIjr!?5hD;#>~U50xs3)GjO7Ic6Bqc z4Oc1Y4Q|1T$0On&Uk#se6=9#KL*o;T(Ul*0T6?0KQeti4Rc5C?V6=aM;-3@neTv(pu7|H%_*7_m_CoZ8T&_WcUw#r|DkWi{_O9WRM7wuy z6P;r)W*>wS1S@Xp-D?I%N|7FHUT?8(%3&hw%~V#-yXS|*R905LWJLdmis8CH+>70- zThoi8;9p$VGKX4*b)6)E?)gHJzE;&3HTwRa;nOz%ALP?Z$WY)j3d_kcB82b!w!04^ z@S-h|%?|;Gb#b-d;N<%dA_Mw57%^W6z8Bmr7nue|=FfV$q}suF61fK+W z;RdtiKVA2U!(b%-tdMw*)ZoRFlf$lzI*{bSX5toqL;zY`ejyn|;037bW)Xg$03x2 z-v1p%eCf?(ukFbee~7}pBK#L4bhDeW4pUFK3BrJiEE^m~Oz-Uy9Qi+@t>m#RCk7iM zB2+JSMw|{1aA^p0*{JHNKyT=g$6&cFAqLD`|7K<%W6>qIn7LYRqe_%O%>y1 z=OHVi+*$1HXZVgNr*B;&FY4Pro|?I3#sdvZ6|H1|C}dh$K|fdbHY#B2wH^<}JrwA- zI48L^&B8SZ!dsUkN)LPLn%uUbM|N-RV%TlD#dlv(*%TY|L&01Z?FT)L84xLEZjknR z-TY~eu95rYW2Mrqr{f@2kAF#Hwi?}!w$@G5l_5X&?d%jU%K{F>2<_Mlk(T01EpVx4 z6E{Gu{7Pc=^|unMNz~#WSz~I*w3S(B6N3JMw~_4$*Diuv{YS#u3~p2Oe0J&ro4uQ@xt&nCR)CK=C@up&`85 zEkwj2o$;-csL)N=y!$Mn7a3k=dmVytUGK&@4P^j)5z?PVOvxEt?RTN7|A4k$_#$t+ z2>f)92_1LZir-Fo@cc8aaP=u{$RL;*U8M7ZO53cFH;`?9Yo`ZSHYECYf9sZrx1}0j zFK-Ui^csxaa{Ot|T5|Jz#%@g7H#bDuYDF(;N-0FxC6af~{|~jD4@l4^&{;(Lg#MHLYK3+b-+YgRWi z;Ip{S`)RpGnd;>1sK^~R$CqZQE=^TP_31#ZR3t#93l8)R2j5`eOmh^vz%MMT|U2l=zz>b>^ z%;e!{;PFdWXR7up`eF+E0>Rn%sSI--*DH;?SfQEui8KFt`yWfAWJWOD;<=MsSH85r zRrB4~?QO=E6`%J-2Lh?Xb=-CeYeYWK?mi)?g+Z{q>=eEjR*Z=rj4f#Ba!&eM`*$Tx zQnGs0JzQi4DENa7pOtw)V8E@DwxrNl;DvQ@4Z#Y5#}2{?dcOSJhVm3ID!W!RsQm7; z_-`q(_|9?VgmRVNa&S!%*pE$I(`8`G3G4QpzCZ;yUzvj}ZgC3hi$wd&r?5@N1H91p zEJXhXtP#M#CaRwjA~IHEXYriEtj?fAV>2qRVRYNsDVkH~O27R1ofla4>~iDCpAF0U zZMp8J&2Wg2@i+6pC7?MUh+ghNR?i8qK#mY^F#COOoqceVy`pD?3vuC6o_2fWSAY_p z%^Mu`CP)iWiP_}!xQ6x+bh+eM{za{HSBGo=gQqocy_3AzhgpzL=_WzkAd}eqVTm#3 zo~w2NTRa#ARL=n@JT(SwTCRNhS*$E(kIn!9)p#Zz)K>ODn7C&?5YxMg8|Ajn&p^^| ztIEs5rPSGL`nYzRu*ySqv0E``dPZtCigB`=yv8y|MXKTTT)EPA)o>GNbDyB>?wIFE=0x2Bm{J(9! zGl$=Q?`%&6Yxno$-UQq+pY)y5x5p;23U`0IyDqQ3H^&8^1|Ar}%EV+gJ*=?%TKgJU z+x}fQOsP7gW(>;*^sxi1|3QVH@q$BhWf3OOVEW$x!q+9dL*YBLeW5b-;uCToPBqgA z+o(1GoRO-Gg)yP9yO0(<;i5d$t5J(hL#H>L#?_I9pnO+C1|S}-D^*$EP~9cGf?4fS zEcm7}SZwC?6(lD*>X9jP@PvVlFYs{4OZ1_Wufhhbu6YU3&%`(L>Ln%?#U^KUHY1%kokqastzb*ojO?$g$ zs}xthXX(}4SIxy1f>-_#`{o43X+zO$`4+BgFPI;tpfEU?ug*qgrKu*6SGVM;jvo|tI)%j*E z`qA^}z4Y}Ns(zB=hKnNDk777jNk>@bIVad{R&$xwb6wdl*#_Z(Zf@vxjKaf z!?V@c#GDS@Q2uQ!3AZ9Us=I_T;5lOQrhAsGwb7RzZ$QTpc?bZE=-HbYwm6eE z!dfEt@x@k8N_mM6UCljO_lp8Ldjlm-`#W!%3$>+z= zHpIAT{>YJ83tJiDJ#uH{ zc64FuZQ|fkMEI_8TGwP<-HycLicSLUE;GpvcH$Qq9;z1@=?=U%pgvzm~ z=ak#XlF*@<+HXn0bck_CO>I{|XrQkaGIjECkGdwazj6(cITkP72`znVCbPYcrW{)dc7M8eOk2V2VXxQ5Eu1i?_@rAk(g@)>W%Zs=*#9t zYqHU|7yt++2Ma9=~|c7w@jYB&>({^zK~t@{c>^=`k~)E%;kxX00Y1) z@6LsQH)0YD0!{JnG(46<6vlqxeYHSxqKGa+_(BZZs}ghLvL)8 z3y_XKC4ssrkIpzl$Kg!%=urcNx4h3q$0HxzOZ|gT68P9d#V6RFsf6HL#K7AUWKPK< zunc1Xo7fDP@4hK)59oK|7Y|wk)}~lAIj5GqZ&PbJ^y!sDk?WR2?wvQI{K5U2dimPL zt(!QfKFl+qT8a_(B{y<1$F`-qLVQ!UCo=JS)PC?{K-bX4k}MX8IRY&H-`9d0u9;mGvywVeBAo?N&tJJbM1D0MAQ!o>4;@fi`r6ZrdgzJlp7P=XubJE z3D_No>w;D`71|WCNW~;tv9c(PZlgC)Mxjl>8@X>oqJ5^vPTBI*&37bye)EQ_bIj9m zm?$Ew*Sn$(@fW!IAkDKDPKIb7v*UJY5d7$)Q&dUmsi*54ci zMJdDz_y=97Tlpf74#CIG0&g%yDyYxNxiXv}7v9#SuZV|XD4;Y`r(D?ey8^iY7B7W( z!q0@e*?V{@KXHQL3B^bQD8_zWOCmEJnwaB}1obQC_oOHM02<4YRga>jHt?kB2jt(# zxI`)~5dHpT?gmYnU7vuB0^}%VHQNl`;xDl%O1b%$l?mVFV|zXCzb=BD zM#jX4M=HICHCIWf^e}ltZ6mJF10I29FZdXR&or@qIQXl}wV{Mgc|Rjk2TDPsfxnb4 z&)f-Y2@E76QWRQtnQp1j!_J-N{s<67P?bEygdATpVaw5@Ks>ZuWz^eE)P09PaDrBpYfv& zK1?#Q4iU8HVpBY9xw-SxvoGjG6AG>MmE`087Nk@6Wp&43QjGJ#P%i$l`KboinY=rE z#$ge61zy*?P*U1_`1-UdW7(Pbwe{O5)1lRiX9+xtF_?=hDX5XMe=w{26NLX-^Pz&%>VqYCgp(^!|5P9a^ssttm3^57rNMFur5S}g?N(G zeZT16UY+zeL9}2WTIH=2HDvu0A?A1`3Mjuh&iO-h-&P;CSwSRUABxpgRP(Uj$ps9( z2zn+r-flq>`|-{hii_f0247``VHYJ7NR3x1UIpLQdpM`bxP=i)1q7-Z`q_n_@!K2; zKY8@B5PDF#Z}54!_rXI$RU+FFyP(v&kZBgY7lp(qH6V7KEy(^ziLVXmfJDTS!GG+U z+md-W_bCV$ZKaUjjh`2sm^DwgQdk;#iM5sAQKy-j+KyvX?^4l+k`}D{A&rsxk%QHJ z$E81Nj_~Q-q2rl8vgni#7-bQ`-j;7YsK)SjE?4N+NNE$TQauf@M^|!wL6|V z08PGMju4wQW@DP+{oy(Se2eIH-#M;}>yGeGtlJTYx(AuVW}#xm!RVbkdFv)ip$;jl znHJJCVgiB?zP4rIuDGEt%X`-J%M_(LV^IFT83GzcSG%BEW!3cazHz#C??bVOuJ;rl zMlCH?sa_^bdn&JQpy$CYf_?o%l;IdmqlfoP0aOLY@8a&wioKb;qK|@ z4LtOWiqoPhHpb*ha)nFpAlel(2zraqRyfeSQ}otWAwg)7K}?uJr4HxyAoVyo3vG$o zqq-v*X~8vCGgXI#k-%X0_kB>$bJOsHnm8pHyj3(!+ScDS2MQ;m#XFp7iQzgka zT=E*v?vO0p?=NuCPga3Z9%OQ>^;4?N&M?eS}7<4Un7Yx=kpH=oy&`&6mg{A|&u%%jVwuwB_7NpW^$9vzY z5=8D`WkNClaUZJ@)fnZ#TGx-X{6KJ&R-R-9|36->4+^}-|CuoipI!S3Qpx3!Ssc8b zBP<^fu^j|=tWWzBFJRG!nJ8kRFzSOBjkOlbo9VmPMeGyHnX&}><)wuI%qJm&T9`}U z4pvT3F*Gb-0hMacH(*f(>D9j=IC%&gl_VG@^ts=$HsVP8ecNs8WI?D`7xv*IYB^?m z4xH_IkARSEr$XWpH}bh;B~ylJtLn6ZT29cRqUQOeBLbFl-5{|xax3DRjB0?Vsp~&X zkQaEO1;}X2GQOa~I*-e5y3DIu?S4Xvgm;S@yc z)APlN@xhu5GOWn`>rman;Ig(TZV4Vr@Y}fE#>BGaLUe4OJV-YU2;Q;S(-9rTXtK0@ zE+@8~x7>DR*Q09T8Dhd)n&L4LlYh?LhgON2Lu!^Q2{i-Eo6o|1huyA*{9@4kZ#BcsJ&zLK zbXVk3!6gUY&f@55?zyt7q9@#RtaJ35*+Pr7fmP!@-Z~q&s=G)DO;`bqr+d{CoyH%g zco<`iJlHw9Nz1bhp$HorhxL1|hYdvDYv^M_#$yMeU~jln6yKf79n`yfjn0o@<`YfA z@tIYRqKhAtdmICEPfAAnI_vM8t=YKzBr&0`3-ORI(!nc^Zq*YE*>=NHY9# zof@XwlNSU=d_0p4;@36be9=$7-4qY5x9T9q8TPmAXth}A8#4M34%##v`?k)T-gKRe z`qX`eF&jlX6Q;r_O8c-B$h`ySfC=iBR5SHqhEZW71)8c>0a7OOQkx)n%`(P5jQaKr zO1G1>yrq>C)Z*uVLl(K{)DRvzOpZ070)*wkXzliq2^rJvwHlX;BZ^@nGzG8DW$-X_ zBKzwGS>B8)?%ow_lYB(YZ0ad$bpXEJ$qyknJr4i@uY5H>>B-|@#uLq;%J3!O3Yspp z!|2x(I>1kY5aI>Av(A>uek>5^!>@h=Su5Fq#!ve`E)8WcaBIY3m%n*h8Sc_2OZUA` z$YS#!x>4b^zfrE>(azLmjA_&jrnu=68|f=Zp>^=8235HHj?J($T8NK>GF^E(HBQNS zjk43ui3yWZ+vZUipxOIVT$liJ!omeC6Q);7ad6WHMBN$I?CvJW0DHks}y!ji4lg-B1Q5p?@o#L>Y^AU zbv%^e@`p^lAux>4)4seakjXi`wG;HzdhmnKtE(<|B;;|5g4*R6f?#AkGCdYu);`pa z2t27aMC$x3c5jck;{JhI!2?0br=IYpRBVIIqryns-<~W})VHIZZrm-=$q}i4WwI8E z6~U%a47~EiDKem|?zVph@W1j%gV|R0gpCU^>A!lKTn1Ny^XH{>36%1-&tfEn(8|)J z?vVM%s%n7f6Lrr5jvMB(Ulj1W1yQfz&FJrHUZSt1yA*UV?J`>0qo{iUP!!`vOV<^>jV; zv^W4M5i$oTV{b`ztKlE&-4W^vW3M>9SJmtO!THHL{L60>LEg$Cs&ra?naF~-C&5`m zvfbiLZ3VcXt9O+{ynmj4?f)6l`;~e@sP{B9r(X37m#zb1o}a%1=idqmtnv%dvuQ1% zB;%`eIe(@o_`KpE{e2*Hrr^|nka)1aEAd<~JqYn4X6Y%aRnpynerZ{8xLhg?-~{XK zObiq3RonUASoP>1!B#pREB+8%8sUvTfYmp?uZi7W{xtk?q?o$d zpGAe!1>*TFj)qEU`AdRjw-(a!-&_d<|BqBY3-Dc5k_QjHpxBo&3nEE5Oiv`go{xV! z`Vz*?Vug@;Wz?7{2%6D_BnjCnr=D;0O(cQHoo%b$o=ORUzf6v2a4~WI&>RvL41a%# z7>UK9vyyWv%__qYW1Hbja&RyH$P(57gCg(2z=gR zBqw{3Yn^(b4L^!-*If^yD%E(_JxG@`q`MGg|LiFLj2tiD#>z85@uML&S`ly*Sw0*He2Aqjo_-U@d+~P5dd^yM!b^H+yl=?s9uu3^jRKECYQQC&5E7 z%ieAFk>e!=5!@Iw*Mg~c*qk>lK3J8%#DJA|=nay*fAvr;O{-^8z6s;S9|fxYld0H} zSwkNALj{L)U*#i)FEzZ`7Q1!6(sbzVFy8CT6qufu7TjctX1pXTe?;BK8!3tef*j!} zmbei4+LySXSxK~Tj^H8zKRm!=w?C?zG5o1E_Dxb`Eh4DpRHvjpj24GR!Hl_DWqL{gkeR<@4af{WP8LP}2z`*5 z-N7=|hXXx`TStR*?^otG#1YP)-_L&272h5f%fKRr-DhH?6P$3pVHUa+LxURD9@=LJ zNp^UyE_yFG9kB}^uVP~2)KBm1HCvvn8CgE6wM@x5WQEWd{@Hx|iFSe(GB_+V_&%88 zy7+r<0de4)Bc<&HIs6e>P5Ct3)JtSpF6ivvm(wuuxcDj@p)>W<23LG!zUT zZxWslE%@C=dZg3QHTeK$g1%FH%_pEtkU+zOIT)k7Nu#+FlB&$BiCFwjF8PO|_LM_z z?ZEPB<_rY?#ZDNJgQux(_OH!4+~mwWWixk}*pK?m&StYMGAYQsKi$6bxuEN&RfZF} z{ey+x^Id%CBK@!COukR~xq%||2^TAfvfcql`_d>I2ct9U#lTgCeL#l@O-Rvg6oG8> z4#2Xn^_OU8P`%G-)huvzwE5B@n!M$27+I~Y6IGuo!I!QiyNr{w^u3Ffr_Sbb@Ri^0 zztW#7I)NWWmYEJvrIilNlcdBZLbhp)+)4wR^F80Pao@+MLqYGUwa#`~FA+W_fO+|foxl0p1&8}u%crtsRMMfk z=LuU13+a$-*kTnPa>0}6GceFxa|0Dfhx^CxfXsqw)?#)F1-`>CuK7gZXZ_*zo_B$5 zxh4WG+2!MNj_6=TmBmpM>%D#cS=G#a!NR`s8zVwvg4n1FG!YwhTq2%vZkcf@d<(;+ z4>752L_pLH>kx_epzul5OM0w4Hr5ogCl4r_75=;=&Y{nK4vgL(VqQl#4&0m=g){z8{iEy&3P0M%ttKC z@LJjeCN{s)s?)oHQ&o1qzDTcYTJ|zYihkYg5&)#{#T7{4M@|HGhlhE<2ybEPo6WuS z$mpo7&$3WUl3)yEaTyT26~eHoURV3}5fXv=0|e&7eDMeUz4FJpSKY$JLLj?C;&?C_ z7_^3~-h~k*DR@N1H2SkLMiOUXnLhZv?gkX5choET9nZ z3!tIc!$arK`Jd^BVzOtlH!nEjcC)R{>_5)J=gs87=HiydE<_K{(9_WJq&pb3t@5~4 z;(vR+@Z2F^Amt;LyXarrZ-Mi0bV60XHpBn8>OAjrB*u`7a2f6%Ou_c{{2BymMknpF zMGK{SPvC#5A#8thcgXAQ{j^t;J*vI=SXOeF4A)-}x6tk#FgTC;<+O=gy5SrYNRf3l zXxw!*)a^lP^7Mw)r#e0`7@(I8C|kd#dsr1=0ZjA|EEY_3_|WNDj=o0fJZcx zi)pSFO26x#K+lLk()-V_w2?z{%NvdR|KN-NbVcz}SG1pHvHeN@$85-19c+`!48)_! z4SfHNM>Kz{}e5lZrf$9DL@x3AM-d-%z5;HGT(^E{pT& z)XIzI;`E-F{R=}sm-yJ@p+n^WK+ns_93)wV06yI8LF>NBmvfez5nzcj|4FSk9XYOH z+tCstL$@A@uuR`TEtvQvi9nYg-rXkn_t+GR>fPW;e}-Ar`gyj}z}ia_@E$oZ5CjZh zdssl{rKYtAv`c;O*6#1sL~spihpK!><%;Ypiku3%UA>Cktxk!Ys{Br>`*k6tH!Oxa zD*~GKer(%~b0NtVmq{$y0ne0yOM+p(t()R(cvfHYX%e?v2y(W z`!Piq$8nrbehCS0m6J`u=)!Luy9 z_}p+xcE%QrYsbOKf#tdGXY2Rdnk3b*1TT*xUFrT2@Qp?_Dbhr8jdfLAXax0P>(k~(IO-A2!ZQ3heTfR0;UM)$AZ02}RsiNz#&wzJ*K;$RsnCOV zAcgE!Z3=IS3UT7nkUiNIY?Vaq6ms1x(ZFYmTwO^iKCp`5-%M$hB1P5)dpEiYgh;j@ zWduipq=*j=pM?oI*f``^3L|%xxxShgdxEo1iqtvlKO58yx_zodE+RR;J>*)BZ1 zFdJ^)|8TwR{$!&NCO@NyKEK(ZOBg#0#7C4ZC%cScw)j*?%-pq)!^vM#58aJrl5&A_ zQjm64{&aNhml1EP_v-n7C$^D&7Bp8#dF%*R-RhU)KYt#S`4y~T+AxSVl&|3;my;c_4>$29)1OKc*?z!#}mVcE9BPBmq3c z1;mdUBXL=Q!I-W^!yItBo42JnAHB3JfcSp{$HOI$6;;>&?I5(tc)Jk0bZ6SxsSQ%w z7$`p=k{>~j{GZ<0l0kt8BiVsDz{gf}3;l)7M@C%^*rZnDCO9Wx!L4YST!%j@Sb#0SCkOsu~Ekis*YIB;b zZ)YJiH%+7l+tg<`2aD6$@BYSWl0$u~N^~H)D^x`*j7NiFxIz%$?bOH9B(w{p)<0`ph)-_Wn-=@BWsc-iF>4r)~0uRY^ z{r{TLbsN9RfckdMENL{11*u5ZG}Bxd{Mx{#tUobZnRGP_i7YFGxc#sKUVftg|1crT zM^DT9;QX^ITb}x7*JW*tKt&1$j&ejfuloJ@Q!ko7lhdG}KyXSC77_#oDQh%E0!cHf z49V+wT!NBHkS7n@fFmEWA3^fxjjsk>c(B*N!_)>&cqPx|_L@v85PWDV0&VAfQcLsX zij|G-jC;M!(+(iRO2!AjbneFaD%tS}ivPIp;A%QJ-t(*RM_3~u8^c530=`M`wUZ?)~|Fh?;q)6PZ`Znp?ukZ;KrQ+z_ zanJZJyC2$uWF8>X5;~gSz$))%1FQ&YC3LHe&u!FGH*vyk*5MbVSfF^NGck%fGF6Z$Zx|w{p9Q4UHuk*CIZ)RKRV-mNOaXB7?2S z5PG3A#;4?8t$~2CYcKc`{s}+{Hk`UY|5vSF)8E&Gf`G*_Ee5s!liElXDNypk0FBha z(}!K)n__W&AS52#5LrIB0G;Tk&&Vd5+oIP)GhNSpC{F&Cu8kH`oqr9vyS~2vll0<6 zqZb4O!xsBEB&(z9hBJyu<%?`-K!LBav$X>o(cjbrrTgC>R5>Rl-IlC@78N<|k>4Lz zi5O!3f1fSUrdi*&`)L-&h|IXunY#=$WnH1^LHZYpIpFIWutV?|JZu^OQ@)qF`}j5P zMo(h^S4lgnJKUc0`ssy4hEvVbJ@z~#y5K2kPO&A9@=uqu50l$TN23oCL83WMjsZu0 zRoavVk3qc}QEWse6FM!W(Ez_!&z=b|R?(io{N{|&0e3%radsFK9h~d`HaoxM^wbkJ zRSSs@no>t|L%CO>)<{11DPAv%$gg~XL&=W+$fA3<1`wMHi78VIznWV8Rj}@c@BWE@ z$>}qDGZpYzf}P=wC&MWP8?c-@Fixp|g#Hgt`u<6!6>@h{%>ykHg|!g z*nTKa--mP_?@C7ZZ~Ap=rH92+<40RG7ue1v7Ou=gS&zn@;m-zPml9#K^-flDsJa=u z>g6R1g42)-(>@6<_J>pHa{kN05UesJB-njAp8n7AFXsN@{q@KF0QdpkG|g8iG^H2- z#)gS!Ed;-U1(er+>+k^{q7sQpWnCgdZ^%l-l3sjAZlQ1 z{a-<;{eOW}Kl$N@lk}P~(Ad!M$40N7Etna`DnLXb!j`BUo?`SM1)0j!-sgoX;(7Y~ z4_VbS+)roy*Ur3-a-RPC$$Tb}scY_L)flkLJyshSTP!?%09d|ES5 z)U+KU8WHiy1hr4DkqR?T`sM&FHLv2Eus!(EOS_*7be7iUb~Q~R<8|w7R4#-GIP(d7 zZe!K75RDPbGKg*;i;0-ZfRbQ6dm=OP6yS!LA_}m-*b-(VopFv@840oP$Na(H-(4K= zslseDW)k`H!=7Ht@WW2p{k`-4_Sr_7R1yDd0hf`;BK3}b{*t(@8r?NErDZZ5OF{{~ zd)-k+4d?hmfs!35r6n$24k#1{$cZZA_g++{puaI~?!ClY8F>ZfKHH;`A6_3m_ z6EG$jh|WPK!(davo81k&*p#4>4?Ng zK_BJE-4g$w(**|%rUo2W&cktE21!CCKxi5X07%V7)h^ztCO|)oJAJE)IM=x|0&7!? z5uadm=m?i3+t(`U@00&t9Q@zCX~o`_{9xl0%@utr{al=|IHf{K1P(K>f%E3W zTb|I-hocE zv{CzhwDOgGrk#0D%83zrVjt~CrhtCn%kPhEK})+xOff*<_G!eC4!zJC{EcT`KFYXj zuAZ1wKbx^HuJ9xM@bUPeiXJTjCThj`{wl~LtU#AP{fI)6Qd~e{aoEskHd^w7bV25N z(lr1XF6*d*eHu3dwfCUjT=es$xf+m2gZa4Y|ItAKn~`W%7GEsegpCKp3vjXy>^Fi~ z;vdU*W+H$JQ}rQL8QU>O{}cX-O+W*-%)_$vPmN@l6$xCRgd4ptfxdHVm@mwBo}vi_?f z#+b2`H_p)I|CmiFsH8D@qs+FyfFc}=6(m#SjF?EkShVF!_W<#wr~O~UkG&T@StPpG zE3F|6Y$|zQ#=xoIZ{PL%@$~TWdRk@Wf&*h3$A-I>BFqUWQgAJ~M#wLKwvKiHnp4~U z9sob3ef!_iH}V-&+>v=ZSa!^3;|$O%jNR$E`YK?M_bQ%2@d82WjCJk5f8 z^daS_l7?OVH(&iI?`y4K25{f5t(Vk?=>hNb;3VPNvK+N5uof^sJ&MIDr)bYKIns!# z9G8@Z&0hPVE!aw2BG z%@#BBVxzwmX<&`3JH*#SN#D9aMX=7YsA32$}XEA+UK+w=K zv;NhCaWL_Z2+^U1N0>gs73>nl;e!3&0AVR7J{v!k{W*G5z`*}Be_R6w2jKs9Wo1^X zQ~@9UT1Wq^t)aZLd15@vHfzIjMf+MgSV?A-1i}(<(>EAZct`*^$;CIvm>!jPPp<>n zaVq=yL#X%r-2Z6+L4bdcVNq@zWP9&l%LOdG{=}4MhN0Exd1^8-$;y>m2-99JE(~6G zjbs%F99E-i;2o|ix4)1)*)DE*-v*8S4mEe5_BU<>L>mwpJG`fF_|cs=@_I@0+(2~M z(T0%OB?h}#^M=~-iei)zu+yI zl+$zFub9j}$?f5U2w@BsfQ;(l_M_LT% z2FC>md#mOn2?%sqQ3_p(eNo5jsQ&O#At8^mBs!%tkLFKAv8>m~d4AzWcZ5<~F*s<`nRKQ1MgOA8TTU zwC6t{`>M0%Dm|iRTE5qV!Za1Dm^4n*ew(0PKBx}={r~Lx?Eg8C`9D!jFht86b+k2U zD;y2TOUmvt+WigKQGaC@;MqI+3B!Amxl(6j){?a~Kf%1B7647JD4H2}qW;Ki!%~r+O9e5moeV&2S`~=%x zjdSdgH#L6K(SowHf7!bU1VgMcCFmBW@vwi6{`;ok@b{iBG(c}A$Tyf6Rm$gzJIJ{b zNPTYZafCL&NwTPWCvyo78f^Oqi!ApG?Mgb#!^k){@Brdj;WsUw3n2- zz-|CQ`YhthT?7$v*mZt_1dI4<;d8_PVNgw7$9b8P^Ut&R9_@q`^_zeFK#YI-yYrX- z?I79?iW)@M(-^gs1mjNc;A&U62i@a*GxFk}k8GhT*cLYD?R)wUn-bZDr$NSu{Zg$7 z7~4PWL*dj(qhEu5bNC?^}EO!HNfy)|qUiAc4DK z$0;0MsMNzvy@e}%DBToTS1WmRi~F=E016UuGa&CyH+}oo0mns(Yy>%yY`^z#0|#u| z-v4jgp|2Xd0^kS$&4>ttucBQqFY(Upe?R@UG?;oW4EP%;c+hbBL8jUzFcv6jUxI`~ z-82xn$ZCMc-xkq((p;_F_go69rm61byz=ENpZ>J~KqV|HSu=PvA@C1_q&#o?Nhhi2 zhm+b!R^-FxsJ20Mr?L%Tt|~fZ7CcxAc?*nC3ICziVhzAOeBK1WB>5LX!V5HR@n&=MAwFjI`jJf8xY>K|C8;t{N+-Z&J7MQPxv2z`f(>p8F#;b{IiNw z$2}*n&cm@^4SX=%%xQ*#LZAyEQ~=_m)$AR6{W$l86ET@Ig1H+Vuuf;AJ?X|^$jzI! zfF`$c`IfOEl0ek!1NZj8xZWDtL*yEtW`?{S!_F53RfvOD;$rd9QfO-2B}9toJbWr7 z52}sUQ(qf}e{K^7|4M88=Gx*n-~T7uv{4ZdfYBF`dzwfwQW}E@)0>O60qtRcxtmvK43)RtzW1>2|NPEBo4&XGv@)34Zbp?v&JW3wr*cro-{6;>tmK)p zjdl(XH|4Hdh=;Ththz=(sI9}!d}qj~tvu%sbc0b3J%dG`z#Q*spLhRvW+~8Qq#j^5 zulT~*%0c{$5E}$Pzl}Xr#yV*gXllqzm^I}*v`~{C#PT^&V{B34Y+RWgR^+S?10Cg-7|;&0*@u&9#sPxz4}gq_~fzEL3TwKya&PDIn_Ji?o+dwAQ?r-$;Vy1Dku7iRB8VsL7J&#}a*X5$$x+6J8n>k^fm<(Vt zbd%2RdtH5zXjS((8J$w$OunxGFMPY6zp*J!CO^W{a#O$e&YpVi<86gb2+qm)=(=Yb4H$|DW(RgB42=t?_MK*#PhI2pNaTu;d8|8!3+qns3>oFOWf7k zVt;{J1O23($N!vq`F#K6`)CqF;=w}+g+`MJx`Z!1L49qvb$KEilcQ7e8WfOhhe}>vbtZCk7;2 zX#g!NXKr%bQOLgL>G6u{Lm6q=rjF zxfR9$XGM*;0VWR=#nnKxOd+H=BL~sjq~x%~ zACzY&6svKas7}6TU#{j$o=0Py=%7Ra1ahhsDgTxITj$;wdTbPpFwvyc;>|O7Iicd; zIpR&A>}t`f5ikjRa4NiZ7J2Xi&!y7Mj@9LTni_}SG8Ya1sxZ(#0jM-@4|wZ8Prd+) z1)|+nPfz;8<1w>hre6|cz0TPJoAZ7iIp7QrPfrqv%mYVXlec^#7L)HMO|%y|C)G#; zdIuIzMt;5+WoiF3X4ep~qg;}Mb0aQDLj~Y=Pmxj7!5M!*;{G%kIF=^@SZORlr>Dta zcn$^t3JwgZXLo>cj^X<1;~J_4kFADKYMdAK{i%L>i`(7?74LzE;9<_emLr?tuZ7R# z`$!7N8IaJ3p`gSO`@SLH)o4&;D!2_ZS~>hAT)Kw(O8|KzO>q~Q`ff!d0h(TU-ZrFO z{8|<`vR;!i*;DU^{rC{Y9RUyNAFNmJCi!;=+-Qus+w;&)TA? z{cf0L)T6iP*OS3a``KB+!Bb32?uIRWPi}^GoNfAGH!1+rJHri^YP*dTaDXS!M}Z*JjAVE!yi^lbKa@7mdmtV( z)oB}idi~v7g#<(eOHCVk`-4tBulT~*#n;K{nS z3tJ|Bdj{UC+)mA<*MH+~-`x3@l0)H>r+aw)B0+H+KgJBHT`1?J4+Hr<)s#L#4G<6; z4P6cBLpaXaIKPxSq(lmEt1i}0nr1hFsqdE}@B9eH>Q8-}E}egL3mMpCPJ#`0UzQ3> z*8IPd-Kk3?BN`y|DJNAjV7_3xOZ8XB{Gp0M9>%m=8hA+q9_A26*r@?x)%OAKQ)ujY zR+%BAzgNL^|1*ys$L@H*fcsGn#+ij@2OjciP@XU0_`JywOK|%B{$v9~h?C6t>S!hH zy$=Njj6~0Sz$wK4iSO&cD3VsH-+4nm_sjcw+p+M`$(-ox{#MIs``o|p2TK}aIHaty z1i=^O&=C2>;OoKFuT2HyM8u%{5UcRT?oR*ewUy#aIwIsBkgscNF9ZxSj&$q(F7M0w zbQlN?4~Wd*{pn=x&VE!L_RO5Qbq3_OQH{{0`TfrT5s1{~}eXY=*&*TU!bW6`bAHp2X3Fm zFV2pgvh&OH>iK@?B?$6*Sz&eJB2yN0{%_p@z88O(TkSYsdQSd{_xI$7VE;Mar`8!N z|2{o_)%AaLJaK>E6b-qcmKC1fws35E{*<#h|G+34YXzt?4X^a^?wx<0UiZ_{%_-o*Y;X3Q#uYw*AiIbLd8}LahK-* zh)2V+_uJragdw3aK=8onoxt)dfhxlO0GD@8eNljP&_RFs>}{N2S385(0r(tX3-l5Q6T&L`nq-06-iXh@leJ5vu?gM!Heh z=b5}Ciom1%HVs7bDZmsUv%Uy1_23c|L_mG70vhm)Hh3@=LHG^)0#ebV>)!~lS7-l4 zMPKXqgZl7?@W?C&tmqD#zdjl);j6FO|A7sD89xKVA_OduM~%J{Tx^7|>i7In;HO>? zz>WS8_(!ki!Hxu#{Ma$Tmc8Zp66_=FBeVT*CU{&l9D5%8H21)z0p?H!2-e%bAsxS} z*#Ht(!M(jMfd2ZM>0BeMBkO@V_rRm@D0~qf5rXf66ZyU&@eUjX0m0`Ha-3^6k>LWt zNC*}ckO0A>NXMM#_unS{h~5AoXH0qB$Q9jxu{AyxC=X5cP!~asTknNLKt*@g?T}0P zgX~o$B1r}XciTeukp0Gzr6hPQ_!K?`4}{MIuzT>D!2lFOKLh|#@IVC|{(cUA6IenK z*McC@*Z;%aSAd!)a#p4F{|ND(0XO>g{rEr%J_rD#or7Hdb%S4o{`@ox_y|@+b*W=M z@J;_c_(rec3NQ8Gnh3h9n)oo^gaD)9fC@R_*FOed3H-PCC+h^@#taAy^;__bU&0l= z8xMd4Lx*p{$PAAMlL-Vx;CTvP1B5~)-UtAq5CV^I2~K!_KLy_j{93FCAaNjZ-S_w` zyYBnDy(-}dX8{G7oof99Yrp~KA}!t7dN&L@TnspK!@ghP#(b^}eDG7Bg1?0RFhUc& z5b($`fY`HA>Kl*X36n(D|G$6h!-IHf=l_6}`cK_2;UA*EKqLT2qkF*s6n+LC_$l~L z;S3{?Y~)hMdMF2YXbL;O&kW~SB=}F^lK%q^N(32v3n9XA99*YsoO}pfLqDznL)7SWXU znK}djEOWt*ejWHv;jA~`0w-0VVO?rI$?BE}K*Q@=ph5-^jZB(r2z(+29PAe7SmW{^ zko}*8{uw_l`&W?7Fk+TwSG~dTtk+B_yp2AL1nh8%q zCBOlz?PV%&;bSWv&O1bvtp0&n#dX~MI|RD3rg5N>U}&+hz8!5!v*OIaTF6O6`Z*U!klS-{Su;F%XBerm6ksG@Coq`;m(t_q9 z424k8LBlXU9mNVxx@d^?<7KkVdNf5cj6)i-Cy(BsC)!Tv$17*Q?x4?IqFcYXrE(+o zde|`J^>A$HVk~T=sLY2jB!BV`M@CJdBGA`{jE3_#xMN}&5j-#h^VQk@z#?lUZ&fY*{1g`N+CYE>h)3Xr3_U<#=Ysj#Rs=Xy6$Rt$ z=sU+bhN^defA8&nsNde%C;m&n^7Wg%bpN_%7D1+xUVo|s1OAt20y(d$-v754>ZJ9X zAPoLF-LDTT>tZG=t~LFucID-u^|z&Xj()A5o;2O5FFwp`=^g!q{`CEm(T3VOz>bt> zrG~T`$G2WWRAM|Q1v}GKMc;p~eqCSup4*t^5j}^* zO1jr-FlVnVfI%aFeeXF;K*>*UbQf-} zVd?B)Exm%5r(brre0N3t1MRSo1lTJMh%{k^=zU6emBHsqFbG_JZq=>;`9_ns-~IDw z%6HQP*E^BAiaM^RUS29x6%vnoH|O8o1?&9&+Jg@#+H_kcG(^DILF_gG4?Y-iMs=F_ zHj5x$Eebj|R4^I?b&o^@!M>08uoV9P{`iIl+h&H?Xh{LL(zW&c{=kKr1bwR@)HFg_ zF4z@2G-3I?({xNPt}3<$%_$y-dz3r18eG@4Kn8!5`6@7m^}h5uK0Z6bnfd%{u3BS_ zFjV;)SO)Bh4wrxZzVBXHKOOJy>tX#hat$>_CTiJ-x`!cf#N*OXzSf#NYK9j5x<^T6* zElCiMY=I-%KjikDk>&p6J}5%Y-v*q|-@iKqF<%7L=`0ZUz{Cs!2oqbu06-TV-TS*z zs2V|l6m?J_`QVp7pR#+V+fu|01;IBR+cV`WFfb*2b+`}su#CQd*AkZjqq8h@);V5$ z|7c7BezoU#xrcn182#WueKOZjZGK3(f4och`*heE{)EU_CD|GbLlGB)Mfx_OJ~`L@ zsK_-zQ9IqoJFfqi6-u!2TPD_iD_M^Hq=W){?Ix7FGbIvH`?*E4&VBA_^adHC+hk#)gb%)snY}CfC?)F zM}kI~Kug)-jlG{dpofK^CHdjCC#MXYCvSAE{na1P?!)l^?rEo8ZnyN=PGmHgdK@CD zML~J`JP|#~mFEPQ2=zySZeW>~0#{D$)FG6lKMk^v|B4jZklk%M{k1RaqagVLqHACu z916XDUl&`W{GVwzVS+@5#B%|ZIo-JZ6eBD(=Lcj2MX61!@CPe&GYQMX0S|YDt9`mN z`Yoh{u&K~xn1iqk!dm`L-h2Pa>xz+S2hB828pZ?D!a5a*@6(Bb>gOK;)k5fi^QXlA zU>wT#YRzlBe$DNaDt8Xw|Fb;+nwQJ<^HTI7H!t};B>qH(n-5uT3+0KDrb{~eod{(8 zdriA6miL?#uB)aa_z$s(?;}Z#$BCo0@vv(>&C|;upz6}@z;;@^Y?ZZG*Dzp3FZ2A z`{(brl4y5Oe4l!DP9G7B0yA3pS<;}*L%xT4#Q4ZQ2zozj zxFL(0TYme?$Vv)}aga9jh67K3{{cuT@|Sac|4EXBjTOfWe5>D-q1#PNuJlBZ{A0vW zFPpGdW1YGERYkmfDX+S&qi&9p^NTegHe$wijQA@_K@1Re^}Ig!*WiE&J^A2|R&|9z zLjZzq-glqtG}VR)(VZ{c;q(naseJH+k-SSi5GF=vygTH~S9Foe6%u?>S^9PC-WTil z_SJRpFz?%Fhm8 zVr?RuhglW?nlPv!eQa`&#D$z+i@Z6>Z_)kuizS|KhlbMi`#Z61susUb-(656oaqPo z8(kadc|A74B(#zmZ5in)A}yGk4Xh^sBiC)NKo`Id*Y~%EU#gG5_BZ~0wIDipVf#!} z5}>HJff6uR+_~xvf2Z8ND1WBQ3<(+_Y`1a+kTpWWQR0W!gm$MKWNQIU8ennYCar<7 zArl9{PWQ~QKL7-QbGOgrfT|t-2a9>_Kji3@S~Fyx*3J3r`Qjbd1^GVH|3$*6vz16G z9ot8LW3gWb8(P6cKLrIk4WId;3)~VWV8-fkS}5FzU!{7o~y6l|DIx=Y&LK>%Rw#a#_inU zGFIj6tZ9>`@yfT|2*-J5m|p&oIuhYIe_Hc zsDf@cJUEmy478VTvC`S@fQLQ-y}46=dx4RXzpnob;q7(T+){7uU4Q+gV}xN-fKUJV z{=yB<=pcH1IQ?_CijiK&cmw41zm72Rq6Z88ERUgmuw3!lId%$ea*H6N13o7a!-PUV zZV>@zi@^LRJ5Pj=46kcEj4l2PVt3J(U3 z@EB$YGcO?g?X=J|K;I*wcNKtbuRkmxeEEzB>Djv44gWa8PqyMM%L<%32l?lF8bGRg zy>xHf@i`N-r{)Lzp3_UoK!8E9DyDR24n{yWX~X#$@+8bbq&lE$rgNsn*8DDDbFqRlju6hf7@wwjDhEA#7hPpZ~DKkM7_~bLlq%+k5x_ z8dgcf^fL5ukniodzm7Z*=a;w?JNxkqh+IPA7ZADL8R1d_1z&C+7-0Qx_?`%kW=Z6l zTmAI8a(^agL7ZRZwesxQqD6?bkw1xi28B7}0}ee(x1ZVb6dz@FMQb(R$UhXS#j-!-{vhUXwZB?>nE*7h8c-_i}rccY1vO?xwF;`2mIiEE)g@0A}|Q zU;cDC5UAm|byI(Je^LWCxQhStoa9+iUsyNytL}crokiv5=7Dy0L%!G7!RG%YXL6~x zeH_z%;lHng{!a|B&kRt-3{w;_LlBw@2Lm5scV5^0tF8byfBf0+|C8yVQr$+1gXULC zpunVibDlOAZZ>RqQmc(2z8WoxR$Rb0u|qh-a@-d}2W79dBy;RHXeikMiH_MoOFDN) z#`O3kX%+(%o|OpL56~G{a`8qe@)#-4@x9Jy87_r~lF~?NK>>5jp<>}vOp26W+=jH^ zcsD?sktGP>L>}Gw&mga%oq!TD?>$h;GPJ&ZUE$S^I3K`6f)OVJ24K$mbzJ2QjSw5j ztYalMn~))E)d3yhzei;*&`=J9nc*c3-71ud^*Z1HUui1*D_#XZmtTAK6P$_u3jED} zGWiHtiDweO&#jvHdmqX4l0jHDV%U5{vmAb-fc}%19)oZj{ULneV@{%QDa1R915n3w5eR;9JXG} zKHTyw&WT7rM8NT@*GrUu5p>RL{=ng6pXdJkRw&<@+K1k{e(efmi8hB>(GJ zxPpx~Qbek`y0sHUOp_r+AnOj^LJ_<-Yw|=x$OI*B-`7mc?s500O(P#)o`Z6_7>_uW z`L(d_l5BTZ@_$5m3n+X>6t?URka)+NRu0Z({{$@mnT{Ri;t&g|8?1jEK>Bp6n|DS9 z$nClKWT%g>^t=@V1NCq1|MTF+iA?AI-%0Hp{V1?^0|7Hf;q))!{y6adekAV-p<4p7{N*2C#(;8{Pv@=qb^rhV zG?o|oT|gP`DqsHLw(PF*S>kdN2`rYs8Zgd#T8rdAH*{4~-h%S4@c$29AKg*9M%h>Q zRb9Dp`+tt#tp7PJxo=YZFGmmr{`EQL3{747-_GD(-P zjD~h6jKr8UA=ih40bvDMMnlcNgC6}IFTW>B+=J$e4ZtH2TKa~W0PQGJzk}LXYRtW? ztAlFNn0SO1hE~5ZFrfK20%jpp^RyF~1DpWEfPfBxX*9w;5=Yx3JwJ9MO6ZznD!L@^ z?TLo>|1;wab(zhMu??+QyBmhu2-+1EGl4=+xxU7-5!o42i1H~V}* z2@nt=^O8R{FUr+as6{XHw-uA=uhW(wqe-MyCb6c5Ae5m7tkP5y*J1Fgc{2uP1H6R5 zA?je=%Av&k0CD>!>|bare{>~W37q!#pYQDQ#` z{ig`V5Gptj8yITip~zw_T5}Ncu8Z-I@XGnS062F(PpXp9;lY*LulN0qcJI*jmx5*y zK~s@`!_(O7`mb8`(!jd@Pplgw(Wnp`QSXPO+j7W(HWh7ZWD)8h*ItuuODsbk0hGq0 zQGS7urk(xh7qvU6)|@}InYVvzit5_^S#vfT^;57)q&riYwkN4TX5<8IroRQC0e}DwZNS*h5MNb!T?d1BP z;Gt%>3W+>Z{o&7+MN7&*?xL9!9g2a92b0$dWf;e?qA8f|MP7w+_a;|{;qQE?B=?42 z02F@8R`1^X4!klR?h5=*yVz;JoKB?;khqLLyX8XqFX0b~_+aqP>_F*)pqY#~1M{40 z4d1WB67-JlUm(Ck?QUu<&;FBa!p51e-fs$em@e|YqyIHNKy4Xt^QL^CgFByH47`_ch2jTZu_a(s9>cg8q0&Qr|k2xRy?mn@4MdT{mOhP|)D`)b7OG#5S2Mqh4 z>SJcetf7;PBRxa=?156}Ssz>R_E!^nE(?|Z#rAJY`Tg7B;$bfKg(g?_t34NQVlb8e zfwk;C3ekdXi#j^YbO^;Ho4mN(27TWAqnotj|LaDfr(8%H%SvT~{Hc2Cxa@EFepz6p z$I0%Y1~g16Skg6d(9or*lMG+iXSf(`!H0D#a3qZJxK;i3S%Ck0dI!Jf-ATFhUt9m$jS-k!J%M%Yt(mlA$0>JMm8lJz4)>)uKmAviWTdRR6*~l?V<$bp_n^U12=wZY7124GuBIxZ7xkPaD|4Mh}|^qCi6vqar%6FXJvw^d@yYchs69adGdNmr^FmU&{PD2a3AlipHJX3@2=k?BkqmY zRR&;>^lD=u{qL5~=XAWuyCvKDmrMBhNMB(!>aTzN`5PzDAM{=F+yAXiyp#MiOtNMa zd_Efc?kJWVX6>ATQ>ez?r{+g>bvOLq?BDFyXgO}?{@;J9|NsAVLWJ1US%N^)Kn*2q zyZ3(s`>v%Yug951oTjdg4a>P<|C#Hp&G@d9jkK5>xUl+5#yR?MAMKW(YMlSC<`p@? zl2?*Ors&uWhlJko>MW|^HhHPfP_5bvX^Y+kDSlY;fve6<)p|SAc+)m^P zCTJm+XGFY+CJ$%MfbT1a(cZIh?P4duy=dP1SBlESxdU9Ev^plxh#)nhq-4^OV~H6G zq7)~rVGW`gltIC1Vzq$*|Ng-3muTTQ4;yXX+xgNAx$0MctDsfRSAQ%Etrq@~Mv{7l zT#4nf%sW*qZkaP6P&E`y13*G!6Hp^~JJdu0db54jus7C^Pt$~CKEMB_KkNVQ>`{lX z(Z~(Gf42R4=l%iLQ^5+>W(MYa%$JB55 zWxg9>Q6vp7*R6H`%62bP3QdQ^j65Un zciqB4KQhY{L6>5vdjQa9yC3^x3mNc}gD0fHa};%u1F3ukUsB~i7N!8@+iE*k=f<^v z*x(SyuK>(}^x{D+4MT9mV)gjT{)%xbuF@CfOa68liugm~ei%QsInxS)YVHgcT)X96 zN&`ZBKagDA-5>IOq^THegGdH8TVR`CUsM{tN84hZU)oE}Bm@K1-s2l-#ky|pU->X< z=rJcshsKx_@&TK1PD1T$`Wts9N`m=6c_7<1=I?ivtQ!J8n&{_q2N`_e={>Z9L<^Nf z!lLK@Z*BVin0dQ9y}pB8>tbF1|HGiEayEE!)`}cr2BsvDfkuwB+ z%?ocY?DhLN#FFsV-Wg$`C&mMH|IGjD_s`__p#?bv zA@ISnp;a}rp&S892&i}6z^LBfHq9ctckFDHpEK2`uv*>5gL>4EW6R{YjSZBQ<_46M z!bmdMYpUQ1Vqe$i#SskKg!MhoI0Cfeu7V$q?SOPy=@(CEE$&aMTKHt}OWS=o{T1_~ zM~1upPf484}>l|i@kw0&YbsBQm;4g!Sv^L+rHY+ zx99)2)$STPYaSfpBFE=zvA4ZH6=tlNunmXF6}%*RQzwkhaOlA4cse+@u*G zhy$|Vly!q325MebZX6YEm*qxeqv!m4c7BSd`z$I`#DHzXo5~cFU4ejMYZWtk#G)PN zuLZIH%lSW|+892A#M%l|&G@Mw80a**b24j!s5p%S2`#*LeEWe`o1}41hVG@A#GD*G}gPkorR``>r-A zBKAPZ!D*-M#{f8Yo+aP^|7o-pV@Oa*Z6z)pksWq8(*TzJ3}uRB8Sb7Fq-){83l>$m zg9p@qfCU@dLok!>e%oQ_ROt=f^s<4x7uWGqoZfD>ndR<-=`+}Yrg?)LxEcJ0&~-=g z`^_a+z-_2PsoWOjH)lUSLI7>1Y_qTdR*1JN``wh8SASd%z#symaQ9p94H@$~asNCg zAXfi2j;#OKxX}#_PUPku9WP!KBo(&Nw&aq@Su0a52co%0hk|g@`#Hu?*jMW1;9lx? z*ZS;4>nH!lAvXNq_wA`1hZC>G8-H27|V3!`o?KUHAx& zy<@KpD}8n0DgP{xZNukseWqa}eiag(7n!Sfj|5Nq^XvTaPSt++Hn4kfQpSk)I;_Ci z4TiSnW)(UX8**2H+gT$IQUxUohnM;SzYn#ihu8z>dxxqG7MJPozAYh?vQHeCKILC;-LUwpa+kt`CN zgpLjGj&$BC0Tl=W@eKn6Y>ExefahvK2~H?bRfaCi1$~UW0U}Hp8Sd}`p1S0}+nZ-s zcRir3UigDd)n4{g@(D{i64kS+X~iDaC(QhyjEHkYu>@e0#A0s8IhI_gB7XI3IvufZU@%r|LAO zpVZ0VgbY2vBst)QUxYp<;e*e&Gam|qLkm*V+a^p67x1gTw~vVPas3KE)^E1&|I8jA z`z5r{xK<$F_exU`#eZsbr{ebGEAR+Kvw zb-&~nE4{!+WcN^ue+!uIcLtCxu3b8&vRXALGcRIp*{7sE2}twb-(OMRUfgF`q8^K zBP8uxkRg|C6myDvpej}2>=<Uoe z$NqjaF~D+z9YfLxrjW4e#irzc5G8&BaPBT0)8R6B#3tVoCiFsgnpe!6i-^9>Ii#`JM1D?lValz=DB3Y4@&-h4pPt5T4 zxw4z7&EE36;H~5GN99k(0&fKIumSp%L=FN)jvyDxmLA5{>W6fdwC>8U_LN_%9~eT3 zf!J}qno++|;S@^L7@>o5I6jr#oT+svDeqzDMGrZ5bYV>Ww4EI`MN>nnwmpT-G@&X8 zARrIs}S%QKidivFZZUk>Ex4I}}Wdpv2t)qOI#@kYe zQ1mxfQ51inMIXXT@!J^R>#zAU%so(tionaTI0!*F#JF^sXJ~avXEam+^j{>#a#JS+ zuFh_9gn=1VqOC|dGOKpr;@__Qi7#&FDM6b1NH#~1CWcS*K!-i>UtizE5E@FVc)Pjs z;4(!T8>OuK#VvC68Sa0UQej3xArBLn=H062xO6QLJ$X3o?q-|)DHSQg0u6|oKZY;Q z*IpFeN}~Irmusj+((0=6b@bgzRRzn+Bd}-vKhkeRO5iWl*1!f1cQ1;e!u!`M$~?F2rqt3!t{ho&Sda$Xj=PGI#sbGP2t=i|&8$htn92k9;0I2DMITFDy z{np=5GuZ267&J#N(yp#Q9j+&>eh~|jFBnL0wZ-w~DQ8GF+xgJJdV-lmvzgm{iQEsA zu@8en_89foO}=?aZJ}n!`sWe5Hud!RT+dlIFx+$;c5rm8G;z^QOa}uQIv-Zv?(9iW>zfQ z_vW<1BbC13t9It~!JehjctXbm4CR91UH!?L50lKvXtK(VV<#ADTB{v zX&Rjful&6s%h)=YWt@wGPddzlL$fXka_W=s1i)!F*Ko{Gna$PzY9@F8Uhz~Tqb5H< z$dyqI^o*t8O?2dy9B#77vQkAx5eMS_F-c$Tpb)K0T)KuFb=(^XTY$K2$u0-prGJq2 zGMvC#Rr~F#FM+1&3<@*&Cs zg+M(RWA9#G_C4Hzra|-4bJNJA#0dH|iV+=qlpgpf%0LSv16pch)X{{(`cm!23(rE& zzB;@oj@n>s6(P?A2LeaI(ACp`%DP&Q!-px4^Y999+$08S=?PWry5`h0BB&pzeFn*K z?8)dUD=Q}2y-5yU&t7?J{W809*&2fE4O{nH|FcZwro>JrVKSe&#M0EZ{x#$Rc^~Ki zFQj|yeQAhhI5ei}OYo#CO>n@A`ll~sa@BbZe#XO2zr5ahp`&E3uLecGnxUnqUZ|vU7kI=a7$aVQjCla9G$uqKaW~ z$bHBp-EU^4^3$Q)t69yknk#g55+fV4%UJX{J(eVp?ZZ0*&ErdKYqk|c8}gGg`xW+cbyqZ zYN8JqdPx%g)t=>&g!XkR^*)%t`h{>YPGelZ{pg)ofYnNyd^^(BU6v+|{Uej^tGEW7 ztmcqI^QIO)y#oi|*v$*K7(%hk1D8)x8oScv9Q&(=8k^w2ZpTp_He9yrmG`4h@H!Yig!L8HHwYM?hK^2c)sIQjDN(NEve-6~u`gUk-#H}KKE3y@K$bhd*6V|- za)X&gKpe;hN#$xLZV}GuC$T3{pyWP!>^Ig7-#PH;3O)!=pdAKq85II=14*GCy!C!H zA)LgZvm-fE@k+@)&H1TYBOU6EzbJpg*AVu)XOV!%Q4v~aO7hFwqmj6>R-}rN5!g5d z#1}RFcfAHW<%UHt?*8h;AK?+&zIouG*D7tptIESxyP&<|Ds(W6r(2KU!9uY|;>ix< zacECzNRZU)vvQoJzLPGHBDjk=g_7<%j2dTYqg9APN>y$?htHjkI%-sq6ZCZrx;$vu zc1(8-=>(ahry|d;mvF_t$Y-HvUf91@7j-Cgn3=}idA~lF=dvw=mC|;)op6J%QqX<2 zSXV&@i7^U-0$wziDs4=3P8h5+4ch7CZ`OV6U{gzpPtT<6n{xP7r*NKO}(U;t)81@JDRo^cQY$K=fTt4lT!>FH6p8N+M{fl zL+mO0&qw;@nZ$F=W}|zK0ktko9=|C%y%)Csj&vz~0?Hk(_%IMBFp8mw6zm!Pc-U?| zem2okIOE9E6G<3PpS1Zc*Mve(?vX01h5feiO;7m^suE|=%XxW&a2(vfesbww=X5qO zY4%(Bo!0HN-}c>~?AAM{>`Zh^EWKg*7z2E?>y@2(`UKsKSNADDEfM=yz4}^3LuMgA8*8;eO<3}ZhdQ1+AA<|+pY+?H;6AJ&9fPMUQ1BPhKXj?ZJT z_FLGjy6^FRnyUs`aTNc43;%OC=b+hlF(Wl^D7}zFa@Zta;&|}ckDP5I@SPX>)rYJ~ zXB}n*rD)EX8valPRDQbXZ)D-{ol3kP;fS>6qbK=&iz5|mVG_j<3ZRDilqqwp$wwCP zR!^mP{PD{)ERiNK6iw}I)%NyoTFZq}822?)5^2gFjN(N#B8phf!^Gq?v{>le4j)sL zG%{|GN`32m&K?pGBX9tUu3Iag=4nhdymcW1qQ_PAJV+*o{zw!OhEkat;{iBMIAXYQ zvxEL!ex`iDkokzY3vt*3xhIzXcvkXIyI_~LUFlDgdWJJxh0Q{r+G3zWCq2?h2Ff(~ z&dq{Sz5rWAzh%Yc&%mu|oAlNn1ReES7BI5ARRo0BR1LI`va>}9<=A}g3s1^)!Qzc; z3D%a+bs-U^=%Z8{!@nke!=Y3@`aVjE1C@H6CQ3`TX&A5`XF7*fy96BY&M@H`evdNF zrs@D~``s%gP<*6RME%v@NV+=hF|lsLXGV9UAIgP)stJj14*gm1sPxgfn!Ma*^khk7 zObYJ7US_v=97&0tbSRnvaFli3MRqy8J}P^Cwshuf1ZO*7@|2SgkvEPTr;}8Na-(jqx!_P-Q&wOt}j&a zD31F58oG+)>Ec6Ns$q!ym!%Wmmi{zUEfh7(EX2chrNQ!(=}_9SKi1z<7n=rG$+D0C zj~ai{zasqp;lridNt$82NyVqR_r}xrQu=P+i3#m9W0-?ivZEJjE_a%qs1`75ecd$F5+h z@T18(qrPfo!> ziJ9~MC_mP_6wwbZEc#>&Q(;o{8m=)0XEO&Z=xXR`;D1CSI zd5FKGaHahK87_=(bq=-Z7juMJPeF7yv2t>eMm)33E-3t|aS~bce)?rT zmlo>DcXTusQ`C5)=<)Nx`bh*gb3Elo&$aZ-ZDfxMmO{=OL1BujQ6r<8^eWI=$TkpW zq8I1|U072%4MF`_KWu76@10$GeZyb7gSga@>ELwohgqX12YD}m5qwYGN@27J%O&ly zgGh9Pg;;OW?gM-Se4*I}!B{M%a~8awEX zevV8+$KX&x-uUuvvJu*yHK8rkcrjbwbNJ65WRbfS?SVqqqj6q_Qasc+iz>EM_u!9b zNa~r5?Zr#ABS&jBvY!&qE~HlIaOM+Fc@Z?0%A_t0!m(`4eQtaITplR4iFWf#+fsk9 zAE-QA-I3(UFhFtr(_2UXk7cbi*9BF>Qh<1((B?v0OH|;;y}nol8`_m8rhK|o%M2cZZ5$-GynE;!MHC}t9~CD=ysc- zQ6N(eJhpn56cOX3HKP{snf1tBY;|k5OIX|x?fF6Q##hui4vXLd_tcy0_GDMOG;eQF zis$>0@>SbV?Ugl8#k=qhv>%6tmni5q)Y`Tcx|EclrtVTM@B8sFJ8|}hGdU4?N=v@j zcz+ra4ec8Wxq2$JUBnAPv?dp-;>wrN;TIFN#CsehcpY@Fq+px<5TudoN^!x#l_4%8 z>8PQBO_U3!J}QwS{@12)Nz6*+YwMW`rVjiO0{9@K-TN64)mUKI45QpzKmELU;$!xu zOVF3nE|~g3;uKAkcchD8F))!EQiQakU#McB=p$EVLv!_>WH39`2+BZe6^_3LTgmnv zn)``D%&-u>&{)LkYgm=`ztiP-w-3*TcAvu7HKV%0_)f2mPb+Q}tLYrH>6Q1Zvnud6 zFArXuKvINjn#6I>;awG+AY~HcTi?}uvzo>P5xu_Zl}@k<(YT89B^dZGhf%rbzg!EI z*MISr8@h-d64NfOj#2o%`j*QnUgWpIBev zW$G^{4YSXk0X%5!`PR>*P5Q{wQ+L;9W(79N*|AW9ZZh^e%gu=G@2Z z`;H=H@G1+H?Rvo3`ojeo`_LW{O_B@Y8_k4Q7PzIy9->QmG3yo>4xn{oN) z;hHCHlbMbk#4G0RG9P@>Mb!op{!PRg#odk6z7E{7u#X#Oj^g%}+m zl?>~VQ4aq1>BsR0MN5QTGf8=d8GkgQmRcu zK3lJojW_VOms{xU9mKs4t(6kD!7*+u-Mck!J|*KOkP~I(HSM|ASAdf*kaA9FgLfT> z_{3EBw^PUxSG+Ql!B%ejd;y4##d*4nkU=*SO6EA1$3HYmoip22hz@ahd5KfX^lYaF zx~bt&y97QOxWCL3y`h!RAZbTqYDTONMqK=J(cjz6QWjM!Ud)8UJ1pesv3}iI*PlCS-F6dV~KTn@}W>i zeAc5I#}59coD#~M%w$}chd%WqagO|{$i_aLE3O*OQHGCquY|R zFTtHzF0kced{O`_9smm-=Mj>dPlVoi+OcP^-YZ z45G<5@`(@m&Z`SE1Sjr+=>+vB_uk6rnk7qcP0a|_32^vMhw9fW*&gOE?0gJ~vCk$? z@g4mFtgwUK#32v~v>$P}XS|K{_Met1(8@OgA97w%T&n7epWmOpj39Q`IlA{TJk><0 zC0@0p%n2zO)gbu~u-*`DaKe-`VIi}U~~5g#zn-LMFgZM|^S z^p}A%DNN?Yd!auN>x>-C+3!FG>K<%>jvH?n$=&&xB`N26l(d#lf^c7QV~lKmLU1IOBp5n=Vl zKR%L{RJCsn_2it3Q7`6ii@2~RmQK)SLRW6+0lZfkuVm0h`Q^zjvkspGdA2i)+rc!t zM=0{UaD59h;j~39an8@uieo@Qe5-UmvJdpJ%J`=3MuP_ZJWELjj`ySE{NLUdOa6gE ztcgZInC_uI*|nK5Ev_ZvOtCX08w#H>?QZNJRsfmSI)KiDgQKMhaWaKbs@Obi)F;J*NcF48KsGQFzIDdca_F%(oh(r zlo_)PPvIR*31YVm>nf;?WE3C?H2vJ_4^Y^k`bRL+c{G3IUei zP2%Gq7jWi+=JOZ~@YUDFw~WjS{?E#ePn=55<&dg~Jmw_Uzn0jUZ<~s#MAHB*)grL+ zCGS)vNO8Z=B#_H)2VZ?y$=CNG93B?{?gnb)E43KF%$$?(<2I(`T$1AiE2N@qUAokV z%MXOUU0&+dStGUk)=}xar9Ni4s`UD0W zy-^YcDTy`nKgqHY{S<(@$kEh^r)p8F@SwVAI1cD7BrUL>>}Pm1X-oyLcU@KMhzJPx zJL;7wjNpDhllq(LVr%~l8C$qwZ?zlwxx!Q*PTiRWcJUhiCn;}@Xo;E7&2M3uibX(< z$|@Z!q!)@%Kngtl_qnCz6YCf3Ma~pVHWqnT*;QY5?joEPOk7ze(Tx)2(v0p_yoh;@ zFB}Rm@(Ba-WxW#rXpOtf+5_p^J(3)snFggrQdwTWT-K{o{5&DZSxVOaXxo;aU+`>@ ziRr1t!Y_)uh?kq41~gyZEkn3NL&Tldy%|q&3!q+9+hquqTm(PTHc-*$3#M76c%KT%T|3uWg`QDAUS6{cFxN*UDc?}UJw-9r zB2Smln?K`*M_`W5Yj!_WPB$|fAG4DCtE?t7QP#bDdWTPccy_&uGuUyW=@uNYP|ri1 z#T~Y=!+S68K*P~{i<{`3{P>0X{NOt1bBBQN@C}*7HwGq8{OYgDoYmRM=S^!7lD=n^ zXydt;C#4+wPkAU^7r$Qbo>_)FBVRnlxL|2Pw;< zt$ahoDzuT6ei%9A41yriP}k{hSISVF!Rg>p$ohnx^)zVQZGhuT3eWc&3^73e{DupX zLWoNnOl@WrsDcK1a{ZzFf%_uLP0}&gF@HQ2tExlmknAGU^sLwV)3a8~PZN9XuQ60h zus9rqh@157sJEz;#Esyb5dG;U-}>XrL|E|pWs7p!6mhP?_Wb6805IqQTwee~iQy8a(}!3o1r z$^ou13(~XD+L6Q+xFkrGGC9WPQ=C?ttCk6PT3b-AiqY1|@=Dw8 zT8nAua(L)2S}BXbsf}OrQy#3VltIIzz*_Ed5JY7f{YPzmZ!}EhL@5lMu05$Xi00CL zR17+WE-t=?KQw!X_|vKy{l(oCh~Y<>>)(u_d2a5BfVUd)w$nIsP}ubQ7Wj)!;Zl@7 zVM>q>3#jnYKQC-CNj4s;!BjQ9oI@j;M-a+jOU=VyMNS02)t<4Is3^tvvqzlN%kB_m zt)5O2jj!*;hc#&0RJS6&H{}a@<9G!AgJL?77a#8{DKclg)zjg(?XWXfA75+}wa1*I zt+&dKzojW>KO&BZc7j?O>_f895am0oQbg{MJbml%K;>Twm6E=FX=||9C?vn(R;eM% zJBwA}1ygJGacK(0yMqW1oP750s4bY;>Zfe<-;jcf8_xj#rU2?7{f~vKBJ^~iFKrk| zea76Stysm{lK#C=+>%e683$GB7xP`&;=`$) zz(OWunsP)UA~?b-7vr`2dQH`ppk%*YQ5%ehB4S9b*#Tbgz(OP~`FU*0^M4UF`{qK_ z@SaHb@@3$8@D%rMx%6~wbc2zY%$MuQ4SvI?m*m2c6zwIj0r5HG^GXl;xp)JocNNMK zV;9z_rEG>j{Hsl48+`@Ho??I19QCC;yLv`P_aPYHuc3G~$cejJvXl|IwE%=SmknJH z_z)7IqL#Ntta6&Jwn=wsy)fQG`rl`_aP2XkmNWw$0EZvdge4?4F#?w7No*wAGo4QU8;lHnJH30e)z(cvdU0r7~;lT~4bN9t3{GjED^#(jp4uEm4j& z|6*=!nJGi{&!X+cF%c0daROk+r4INI{v)H{Qepg-hJ5H!g!k%26B?a)_<^IE(2|)O zz`y5G5pZGpy675#v9bJZkIF(14V|`)%xX6%UHS2$ez3hQcS6kcXJFUkF*islhfndY zBjZtcql$jQ|wo4c| zd}Quw_cNJi*4ZW_GAFsnL6&TRRYJ2`x2R@3m-?ZN9IXH|_Rmtz=LehA+ zGQ#dV^8DMdZJO(aT3=+KKR=I^*HA2cx}SstBh|=~beeNJnKMyN!-~4vnB9T1be+w3 z+MfYBA|EJR#qD#3NFF`<>{>!SNd2NSITfZAuyjR_#Ew?lyXR0h^YqkH%JQZqXY@{yl%}E5q=K;TYsG66i_Iq*&t> zNeurDo={mhwyMDsgY&pa{3Zm1lhLr|99^v!o-`QZD%!c=AUaSD-w8f5Q#ZA*h*xrL zghy;Fuzx<;CUh8mdYfYW(Dgqd5P!cH5L3s3TvgWxfY7S#OCp6jd|+gp0>j?U&C5-6 z^6mS~NpbW>BL=rqRR!HB!Gs*Fc>gzYlZ)o03j!x!jqAcI`p_)!xZXmF2aiEzOp@kO zCGYj)>0;*vDCn8m1Hmo(WJl}@I+URz#?BBPuZD>< z|L>^7V_ip6_IUM!BH~p7+_gNPzHW#AQCZkHy5udp2_oL5)V}eY#F>wOWw=WDA@N2^ zB9O|5@4oNj32%>h0k+f_%gf;VVbo^N+w)-RsWxR?qJ)s?TaQ4>&x}NbM1wZ0RHKs= z54eukWGmcPuvq9(f&%SnMD&IxNN#5#Mh*+X&3I+{?$PIP0k0}J;XZ^Zwjb}o63zmNgv z_HoFXo*#BIeVQt%x?D`%0OEluiPK=h4X}!sF7+h)kh!y}MrjA4wfex{MvnMdHn-MI zzla38HyU?us39JaLjOA@|Bo~lPLA?O7K?+BKB5Edum2}M`X9Rq?r{<7F$U|6Obx}P zW5i3aUf$kF-9lt{&T*LMg2gbtrtm9R-+v68fM|-mt?=M&xCWW^>RKD4A$45@W`GsN z9W{X&5NubVj*1Hi2)~1Km`u!P1!C@Cbu3#LRhV3BkVp6-EM9Dhd9H;e(`uL=DcJq*2I5^ons7d<|Gvqe=M1~|ny z3>Qc2#a+}W1i3zN!uO&m)vFr$k4#{_&2aJj>bS;*iykE~h`muNCN`Ro8idh(d)VIB z$x`2=ELwMsm?CM`E!zrj)0fkj{kb##HD5a4_5JXCT*ZTb^OqjIeWWDgJ+a<>=TXs- zia_vb)hx<-Bc1d7eolwJLUW*NjnKZhzmG~T0{s`Z(?@fT0E z>FeV>qtStx_X}J>y%9EU8mvTRBe>U}zWh9R&ospozcm|S;eCh<-=JkJbfawiaVl%f zXa4uqC)!IIRxd=HfcYe&rFfvD(m@2I_+FVRTI)^>Ydnke&=2(s?T$ge1P-Q6K9}5{ zHuLI!9_dm|T$V>Z^xMI6Yj%r*bg36IMr2Wf<)38+ZhOq zb?J>gTsz7(?>O$^N$TA{sDwx9r_3hdC~hMwuXnNKguY{gqkKkx(3dAt=5JB65#@Ld zqm;0*#D|B%pG%chWp}{zYp*p|8rGF=?!&u?DO0t7#^kKbt2JO7sdNdbIp(g!M=BPo zzkDFx#UC8j8uA1*TW*EK2nNS3-M+6V_h}3HD-spjqOaHZ?R6q~R6+^)NY~qSk=3ou zog{UUU80<}y+=X-a^mR~bnISJxC>WpY4F^qA5M1JF3&5-*ef`x5FTIS{-OvK9ua0? zK6C#a9(TN^I27T5-2Gk!L@IRq;zmMbHnIw)Hm)3QC=36^%~H-fnf}x&g7Q!JzX|Ef zV9!#N37Tei!WDIJT5$UT(hIz;{MLVl){1^TUPdR1FS~G{-ZNJ38-9e=A;diLH(TJi zWC5lxgT>ZZ<#Uu2q~jZl<&9&CDjWEwG>tuwx36??kbpqiLOs3ue`*3>ps%lf$hn=F z6rhogZ{jp?r1tJu@j2IU_{KNj$DO~{sysx?_9i*srmw#liQHFtHDDh)T;tM(^Hu*T z_EBxwUckF?C=z-nQe>>)0rS`5+Zlc5iGT{(bE?c(>4lEt!A)Xcg{cG&<31TT|T^?PY&)4EASI@;J6 z@*97cni*?3=(zEqFNB{#(Mb9HVv&1>dpZ1g`HkqKivV!}m>=G>a}cm4*5DVVt6rOQ z)cKzdF004=*pRWt!6nK8g(qQ^TdIC|m4TVV(E1aF>o}R2QR1wKHwa0>+8`TV(aUS{dz5v{OSA)T-+<2ItN}$I4&x zUng7ondX!tn#;2AU;Kk3bXQPNQ~_gf_#;fOL=bSUWTaC4alv3&P|=8MJbTtv)nRZ` zP}Tl$Ff^PH;H=e%0NDkN>CS*x2EkuBSCepg<|$?UJu zjV;uP$}3ez{C4bg)5*{q{R3ED2`oz>_aN}LiYn_HPKcB45e$9$7#xl?9j6Fuj@R)n zyI4ocAmvpy0GNdw0L~-Q#1}mdAM}b$6{%jDXy0!1>Ecy3zB7;>n}XCXvHQ{qBah~m zZ-A>$Kq_H43eq~pAy<$L42ZZO&ryenkd81}T7NLg(5wUh?QEZQ^S9eChZ_0r9B30w z`*D=Yi!$BsRZHN?M{^70a+)8mN4>k%sOC!@ZLK1=5Hg+NT`OrEq{qKV%G#gaj9%Uo z(J|itt%0^=dJg|yet3GJd-eoud+SMC2#sPwyCa~Tu3UztpO{^Cc->x^@=I-Ff5L8$ zbF6onkvWL%4Sv8+tlt77J!`C8W7>7Sk!GHLIHubeSwaIoM>@mNz z)9tuqOln(yX-9nlnyj+|Vk?dzFHtf$%PnBEMG!wvfnzRl0-4T~TFM#f&RjSZ z{12!Y>GIq|-1 ze+jK01tHWLp~S&kzefMX3Nc)-Oy26;AQGVPvqaq67lg6hDy8rk^C+3Mz*3JAZ+O5# zprUBhXFPy1E+7^jug*PBc@MY4`wv=oab5wh6lb1ZFne5_!-bNR)vOp z%q+AP22ZWW=XCf5g!x_{D21Ay*+L~?E?g5Ij7D8kA@?PM#EHhaJkNXkhtnCL0D&Snu?XI&_KV!=lxk0cUBAqg0adcVnLk{RBY2N(ahLv>pq|vV5g3cnzKmMLb z`z?hWl9h&ZfOl#Yiz$sqb@g6iTjo{*yCED zq-TutfQ-ZnVdl?Y*b^@k50&BJ?R@__pS=!||4YH&|L%u#wY}k6AHYiV@MihL(arDY z_>gjpwr$!4Q#6w^GsOE9qSRaoI(S~c>AY5?8wg-_VS=WN;ET4)+&?30E_uP4eN|D3 z2Bw9))c#>_`1)^1pA84ucc0256j-bT6IL)m_XY*K}P>;#BnK6J5`S!gv5aD zw_Z6lS+gTDjofm8mng$yV&KPASdTBV#o_+0Ns>wEukP*b4mAAk zq)gqGBN0&S+Dv|YcRr16V?y!NxgEI@VqmYr!)LIHVc>OWQsXk}qNFs#+yD$}4BC3} zpwou7XdmGSH$mx$WuCoj=NgR94a;Ip1#}{7!GW$Pwl?de8+_Rx%M{x(c647ypC7&- zI->sRXA1&thJQ8if0Z79{CMVUz9d8y$~PD=cfl)gk9y4i2mK1Kz_7jWIe1&uSZKxY z;b`{>`XyK2Jx3^aJpSH)ov-8h z-HgmMJEg&+{?cqUM;bkRa?R_MOvTY;Im(5TDkg5IG;8k^)vo+Fgl*(EEkIxLs-%8! zuM)LKZM=UIsMFo2a2Fd>{*_nJR<+LP(%pTDF=zIb?Tq&jKCWG2@602G=URcu(G_YY zj5z{J6gbswXSuwjTbhgoHK+}}Rrb^ogYUxya7lm53Sj-vwHV7CYUS-?Rh23_Pk!78>tGnByhiC21#m^1<8&Kk3 zuT!@#Gsf^#{0+ZR{T?1)iQFS%%&&Vo4zw#KWJHQ1Q3)?M1#A3$%L8H6g(KDIz-qMu zj;n)Md&5Vsc181u#K;PKu|Q2O^;Y!zA$lJfn7Mz-dmoU~ZmIr09CR{x|17$UnqBKD zjJW|T!9>?~H3p&OPYujPeEC&c(c9f6-ovWgFs+Y{tQ(n6CaSL(qJ3BUWyxa*y83g1 z0Dp4vl0UezMm*e~7O;x>89$K?9G4ecyI>+SPQ8`Y1jfh$(bo_k<0(G!aqzcsKR5@n z49x&fykUH#s>JCfb8m59IUM4dbU5+38+J;!Q0oM*KLxAPPgSE=4m60*Ws~ zr`ebGuv*N5qP$&~DW0uE^C3ulrR(BhbqSq zO8OZ)4-04TTv&sny%6j!Hx9UhgjH zxZwT1B*>(&YC;Eh14$k(W;-#mHD0(mT2t@59 z-(EF;og-Y&WtX*|uGLLECN0w{`^)|Ay8oZEaA$wn`vHK(p=E^+jMl4z^%WX506Pme zDql3#2EelS4uRWm;Id~?rvHxWM{?YkTz`YkLplpyNUNy}Kc%|W)*-MJl_P!U#zis_ zj2{%~J($JaEblvu!z+B27sB z%#Pj)0A7gQDxK(4QZp+5`-tzAXW2ITZ__oN^Gl^Fo2!)4{HgN(dyPE?^b)Q6e#`yZ zp{Tt+R6j+!U^$jizkOuyqSu=F4C6`7{61*jgO!9o9=SUx`EAx*%2WL>b{~W3TUB-e zK)2SJ@enkP8loP)s&eafK7aI%THF$sOLORg61;uqsa%>sI+#ahqozDKTmddN`}$5o zm@W!H5>r<(s6PtmxgJlG=L=dWEW0&pqu zcRHn5eV9bOr(?t7KhD_cO(AjTn3WiT)T&F`)TUOqFCMjdTCUbo)VLllb} zz6;;>29Zs4Dc50}-eq==C^O$TQbQfL+*)+)&wI*rqN~EYPl3P#3RIxo9M6v9XoP#g z{)58Znf6HbCw$FJ8QVL@Vdn{vWdXHq8Aei13ydw^^ZC(zRGa!It+cC(klvh)9v_yq zWZHNlo~}-GnwX{NePP+^^EGr$f7i_kx*GP808YaJCQdZ6j7L|X(4R_VJYTQPN}g?^ z9c`u03R&M#0xh2q*-4YpeM-k!^bnrb32$jPq7}Q-yruf`f@rzS3Q)1N6v08#Gyq2D z>mhWkzgI3bi{)Cpji=u<{`FW%Xk-E0HuWmw3zlZm(+^JSRvN$G|IQCrvk8=brWHR& zyvpZ6j_vR)3`@Xq7{ekjU0<8Hscla}3~W0X81Ea*DS?d43o?#S?~%y+t+0dlk*NNB zXTF}l|3bNp^n(?%XsP|}Ct~3l3*{yB&Frp!|0I^-UU_n6ij{X?y9K@KjNbn7-7Hr| zUDI!0kbE;UCf5fI&M#eBmqBNq#?tc3GINxE`TBS!_?$+UEL<2CdMeJeVf@YR9Xy06 zKUj>!&F^eW;)_KqE4=cbjQ^F?-$S)RUq*3Gm2}QqUhZIPrSJ14w@ZB;Kjgr}r!#Li z3ZyZie9|2WiB(ni*;vY-_AZPNoS#745c5_$lQJ=%?8dz!qo*sM+-@(%(~?Ix2No$9O1>iJ4D)c8qKI&HCHuaA>3gJE9zfBM<^XiB zrWDlM2m&2Vtr8$#LEq%RDCu4HRWmyLS4heZjCG|s4>41vxiO|Dtw*PB{dHkk0`8z7 z(&D8zD^QNMuo83?uKnDESS2FsMcOu zo^gg<&b)NDes{(TxdRWk7sjO-=RDw9zkM9uMYTw*l`L@1b!wXe=<7UM5AqMYq?P>j zTz5;b&-8%SSucdqxmvL3LSc1B`4Jq{NADE3xgH)s8P_JF5)t?u>74%a_9_+_Fen1K z_MOe^G!E4kqdZ1B^Rq8D9K`BCC!Y~NNT+GbAK643`e54Zrr z;;t$U5?x+<@unWGFXHP^wTI7_Uw>#s;wk*BG^+rr5qBGN_sCZnz3N*4qGNXd)uUnF{1!IRDG5hgX2ySq31q02!kIDGQ*uuE|P;}Zhsa8)W zIh3Xe#nqVK2-7C;ib{R6t1UKydIE87+8p zr5jwaepymmfpFc}l};j2)8R{egN<)@AVSvCAkXca9b2J(#oGl2WS~;wuCsEZ^nQAtmAT7N z!15<4Zj1!y4AMA8fRRd6BrBp5^R)$+Inly@>3*kiUeNsje~oj1Fx|YFvQMi^erZL6 zMf+EpYbA}pvm@{&>|AR=g9@s$b*0uSZ!H|!WzLOE)A3bjM)0yhIy*w!l4&mHnYZp8}G0s8|`&ITn`>t zYCRb!s{HiA`0Ckm0sNwyr*--jhbgKU(`W~GPk(r=E||lq=p3_E3eMHAKV8ggA_^_S zVz_<#b3lQFnHV`97^`e8EI==bh+W!epgaymS81(t>(+~Kd*HZQJF-%~mYwHBR`Nj( z(#TS~qh&VtAYaF0#w_H$-Vwn<1Fj!dm*0{-3fovLqn2^Qm*=8M`UlH;t!Bxs{f5?l zs%=)Wo8SfG>|}qB(d8r~z_t=IFMrw{I(}Ak-#i99&vd!dU;oo6UI&ZD()Q6#^$Bg@ zES4oDJd92_Jn$VA}xd>miHY-Pf6%R-|fv`-{Mc8Y(9+)O!8y)XsJX=;S)az zoHooF7E_4139z78i*Ajqc`e8qPDJ;s>$4nzqf2a4pBr5PVZ@X06VqE5^ox+uI+o6x zicXZT(aePYiw;r&avy#q1l2rxenhbntxmrZ%={t!*LVCn0T`W(q4rm27sVFX%TZ>%E|6> zzo2JHzei2@tI+rPIvdpByz0uJQ{*0vhY|F=or$;18MPJA1W7Rp_ z*%||0do;_VQGwPK#fJ;MzB{*MUhA+2jxVBOi#ICi4h06Sb>;+J(Yrxj3oDaWEHV$z zy)=UmaIX!?P<4X>9N^v6io^>i0Zks#d*DrYB{)2k%dYNx%;PfJ7fTP%tho0S`@yPV zUuO|8+(oTWIn(QGU7b(_CpcMq7-3Ra`8gNq?FSA0Bt(B87+VuWoBzweg1*ivv(cN5 zaj#=~R{-tM*vkw;wYvVCq~{rZeeZt*>@Iv8{B!@@R}oL`y?H~=(-#;!3cTrD&0$|; zHVa(*9UP-}T|k+|zCahP{wjDs6ZRoNO3%sq?-xEFf+f0Xx{;Hf1r#wu&F_CiMbZbp zLk~lj1?`@VGHH1v2hSa_3sfAAPUxy!5YHhe3?RMxe#i{hhj_3Y<&;9-qsIp+*vusyDG`{D@MCV!z=2i-eu2-2s z#zj(Kycd5`!cwpNMFAoeQ3x3DFXv8(@vCy%&`L_gr%GO-BsDOXw9EMrW!J{tCj~*4 z0gsMcMB^U4;)dr|!G~qiaZ_2!?x>5yEojqBsak?cACIDG9lJJ@Xf%Vi^Oqp;;lSDh ztR#Q+veWdi9%m^iYOSfY@nTLmI*%p^^IqHuWfe;~jc#3e`gukk>?y$M_2EbH1=MbU z>>b`T3ssILUu7ek5`z-8J-+TYYWm!HR_70qUqCk}Jv4S9y3Z=M_lZ9E)BUSgF|pV>XugD5Q;jH#}dEzvkqrw=8vizYQFE= zytOn*?dhNVdFe)E#SV&5%S zFo2mws->wD!9j;LO6^U;q)q%KpK?M-Fumj48(ZdN zb3Howpd)0!p4tY9tl~vTO@b{?*m8~EQBEw6!TH-(!6dt~)Y>ad)04VztS^Vv|0-A5 zR8wEe5rlu^u=mAMc^{Sg;%eW637Q8){Jy0PF?r&WsyJ9WK!fOB8&I)->+cu+I!4L< zYyZm}tv(zn+lAg3ss;Z^9M^+7=$F)UY086Fdzi3Y zTi88|PK;^l9HQURMZx9ERIgUALAt+SRAh2%H-)EuWJ%I%mRUi9(N}LH33@FGnn4bW zzj`*t+qFz|OTe3u;UtauE@^a}xd&K#U|=A|ne30(cPemMWW|L1%*I7%h)mfqHm2Mp zR-F8sV-X5|o-)y6-%*CHse~ZS1$Y3G(0DJI#7W!k7tRxNg{p9*>BGB6ggM;Ed*(iw=Sfr!MxGO4z z;Ao0nus8ah`m+1`;hcyeJ-^5#U8L#90j zjbD%!)W5})?EY#B_ALwaUX)Ra5YgI?>%d*lqLJ=Rzc`!Bmg_Q<$Ve$ufG+nu2fCjYGAhi8I?&4 z7$N`Lc^(04h>yVSW#2yfzj5agti|r9SP`XIbZowW{{)8 z8|+y=p{-}bqgtp2=uCG6IGV+je>18bT9m_%ix^b+ZbY?7jwWJOvCXG2|4_TB+L$&# z`;I*djAUKBM_I8)J$;PXk#dg`Lb85gJ&F?Er@K{l-ZcClx9Tx>Cmy*03=>2J5D_#D za5+-`{KiN@A{6zV)#;I&#U7*r*=+ufu15NVG(8KQnUR~`Hi0!sDEjpAM1xzEGkqKY z_e}T%kmXO@yxan_h?<}%>r6~cTTM(*-MIbe{@_tNHrV9X)_0r55F9Wr@*_T) zXniYSgQ$COkyM$F4@+nd3Y&cWkrEX^__UXG zb7Tvgd_}l%W*r#Vv;WWgF>&XA(3>zMHIn`OhzMJZS*rnaKuOM?{ic>2l}-y#5R9ukYYqV5bufd3bb183Pb{J42jf(n!l&Z@ zLXp6RLW9`}?^KAdc?%8;h=n|A6Rm{sj}Z%-tZKN2(A|*_)js}0WUGw3ww#(FbWVjq z&s{!;YviM6THkZ4!kaoOcR&(5y?gc*8im*7DJh!UU)!$rTYmcqr-cXQ8uNB#gzIWE zyxt;G#lM1F!uwAXv!1A!j8!AkTLogx zktxL2`D%FiG39}Eo(mTrVg7Sfj-90DX|L?csPwq2Jny!i6$e?j?Z^tsM%MYmIP-B@ zJ1P>YTac8toYon^TGdwTy?wqe_yG6gKP^F18d?dSZH_i=6E3B*VX}ZgMrh=|wg;J5 zmy@3^IU922k$w|yx&S}xmz!)~>EedmllxGofpnQFSlvf zlcJHISr$j+<0)9My(~^P*#pm|DDf@5^yN`B@>7co)8v@3QgSOeKgqRda1%*>X!DjV zvL)22GE$F3%D>v_6lNwQ=P6+cYyYzO38Q@JeU!T!Bt$`or(i1gizj5~@n#i$uR~xLJF12q@H<`v-?f*(j7!7Oo&~#45tb0;aZR0wlcnv77va6ck zd|QEq{T`0}Jhdg+^O1u~b2b!y5rVmJ%g^_#wayf%x5XfNRM$m6J_%B{s#uCyapAcu zAuuDS_HX?*wL`zq`~fk;ZI|^YFL14Xon?5jC|!Lz9^e@D-8jH;{IQ_gAA(({jU-aX zSz~Xw=|ildLkEL0g!&O&4A(Mx@HaSf{ZoH}`aNZtn{d!0DTxixVglaY4>&lcPUXe; zZ7`7I@!cgGYTxTVQOJSM)?NXK$)Vi)*!q-+xD(W1)8&I6nAAR^nil@KzyA<$L%6yK z$H^>1A**VypFU^`<;#@2pl31AW7?X*NZW$};-|DRE&hm5yUrATTV-=&ONo-H$ON@N zZ<06DxsS{QY##vbE9Oz96nq>Nv8=BmE?xo#@AeNmn3G*1S?~846ms__*wKuhIi12v z&uO-TZNM+A22wr%9 zbgc8{nWoIdCpH_!5Qf%#+}!-pQ>I1a1W!C@Z{DyjCy&lE*6d|?+g66+L~BM?xg4Fs zH|&HTt8|*9(o~W&jM)OB?o?w;%CkNnzS1huHfC-N=IbT2u7T9?Z9_F|J5l5x!3h%< z;6Uw1Llu&a$VPpp({ZDm)@S4J!KW5v3|@opAwRZ#n)cp{Mx)h?+x)BX)| zKC{MaPXW)0JEIKpi}KHM>n+QkuMs+a2BYsWjH$q|4Q%T-{Rx=|jim%k!~@%c1`aAe z$aiah5~n0D6r&B7Y)3ujbSU7avcoPBpWob#L7UEIl+LOP<%$B+UK1@M4OZkx2+;TC z^)0DXlgY{HU(3miDeX%4@iNYZvi2Jpv{C-mq|l*Y-M_mcMi|2) z6N)T~ScP7MxK_%gYpBHN%o<=a)tdaJetPB(Vl14_kZs>4F}9=s-7P)W>4L6yuhN{q zAQtL+3e&=b{UKdmgGFrkMJ|I^tIp9HZY1(+m4XmHqc1SH-FxRa+jI0u_Wy-US#!4prio@D8M* z=d8{Nj_Fgeu{}kIeZu+6by!PvC=n(3ffmgAI;it($n^$}1OR`9IwG<(9H`H5iYC>| z#0XU?Nmt5P0G;OejNz+QjomD9tC=8@`*T6l4Nj8Si+w=$wdJ>O*4O1CIQE(G1}?== z+w)Fy&Uvt@T&jcaN9C8MH`LL{(bZJ-i|Dn6yEHJ-EGSpy=(LOOstwrMyePgZXz)KQ zwQ7%3NhodY_KOy%4WEX>Qck9jAAi3r`bx>e$B=D85HTfEeQIk8{K^YDz?ZD$710 zkPt~o54EkYxZAtL@uf&56nKL&jUVJBamvEpm}Yu#Tkm4DJ=T$R(8rQYToVLmXk-Ma zSH|r-55MXka<4HU%3I1EbgsoO zUlk@c%lya0eH!dXV?ITqR*QpF91$?vxIT~~{2DwF8iX`id4t(tF(HX}SpW61GjN_X zm{L?|MHvfo+W3W1b^C+B;u@s?__QTq`NGYWaV}FZ8TO+oExqR)65tj@#%d z?}6I1bjE|v(1n27hYcTHb-vIp%EvltW{F7u9pt4T>mjY8iID@$NsLp|jga}5T77BOXQF)KHTCRM9<-TCe&$>SIkW;z^w zNW(}r0KWJQY2qY{(XuGwhS3f`}srRA+sM+`Kf*5uR6hH+rp?`qjna>g=# zBBt#5+;+)=_mBAW`A+Hlu;(s+T&LZb$m;2oRt!#SK-n5C#o%dydDsyYgDb4FRV9u) zsmJvGishK^31*=G^QMmtr=_rP=t=Fb$syINuECR7N^N>}m%c#O6Q-Q^H^0O)UGTt6 zu69D|eL;k+9EO|{)9Q&DpadJuE;vrH7f0F&VPN-(El|@CZu`4!ui*}5Z37OT`poyL zD3Bl`S!&@(h!?t2?-wx3Q1qrZfa&cm-=_V#oj02D8JX0!4`WKnB?CP{KJitgfvzz3 zPx=jW`od}9C;iRTWpjVnwajV4(bTRWHobqfaaXj_q}+C}+6Nm~#`Wb_R&1=VE5r`e z|D0)m^{o6>>a1i*W!0AG-6Hl}JC1u+g~jwX%1Oe7B*Z?%>rVRg{$d5%d#{!tYzfU= zuZ*HoiXvA_!~*JKV>{Q5(XkAxY+uL-$_JQv+trPFEPl5 zD&63UYw>5n{7(IS$g`gSe0=17i|MAI;;&aB)jowgB;wh(b0Z6R3&>KUFT=oDXq3=| zq~q4h2gKn&FxXJu8(hT15V5>?Mh2}bv5qhXGFN<<_~_MmHXVFqw|p*JE{1n#45J5| za3!J&MVMk(zKU7tf!(5q1eF1doi^77VekcKSk zRcU75iw>qI#(k`d>+50f0;$#BDU3~OT>e8)acI6QH8JkA9lLhmEY!#m%OcG3#U!f2 zmjdm3cLZeHm0=vJoV=$`9@oUyV#eQLORQtM0`c-)&7sa5ihjTaMhA}hLIkQBp3GFa zFFw3++5jdkt%=@{%qgpw(CXa=a14A?c^a$x{BpI_2rW0p=)=2nb6l$BpQpF@3nG@h z-~MUe>ygq3I`e|CL>gRsOE2$7m zP_K*PUc_9o(h-_yC?OzgDR>~>T~CXJT|bg)X2S^-1H47ZlSpt)QsF0`fQ;<(oe&(v z(XEJiE_6lH{@g;1T@>xa``q3|4^*o0GoPqmi81czn=FM>8VA7wSRACToBbRB+*am; zjz2yC7mQl*x_3I;?XHLyLNdfdVVr#>)R;@%HbUrV-qoT}vWW<@T%tsZ(X`fJZ*h4n zLM4h#1@wR**0?aT1g0A#>H7UKE{Kd7g*?+PKHujf;0fOdClYG=z4LZPXm0}5lY9ZR zMeEYK3Q2O4;-U@bZaPnhTuPXmlfMePB}BF^HXw(gBfV|Y@AOFWrpblW6C56Tovn)1 z9+Uc(Gj+8NyF*0(cB^zu|LY#09n@%f*91y~LUz_jO8fhd*+$W6ngznfw??`3<5bFx znchkws{ks1H`D!Q*LuSWIBK}x%l)(gXmZ7;kXcR{Vy%T?E?4!!D`l^DDKKfVEkT%L z6{l}OLJvh3h7U&IsejvlF<4j^SrKm#rUg5GvTTiOn?yCQ^h5J7yk6h;O57*#wUV&B z#ZSR8ibCD>jakEOw-Sjsn^7Iy^!x#d#2+vZqje_~KFcwtZORhD=F!J610;OBbFS$m zpYC%%iSc0(l1R-4H_DBvnmtnf(69=%R8c)cPMgXRgLoupnIR8`3PuAWvnCVV?Id|x8%gJdk~JO91pjKIOf$_{3_(* zpmumbdUJ?wH7`LshXsI!2c8*5kOx8MG35Cp5p*Q%eEW4)k?VL#`>u`Z>BeTVNRr7$Gu6A=066Y%8@jaT!FqcHq>JwHYozBuDEAMP&b z+pu?pfPIA2-9pQ9aeW00+u=dVz(wKF*7axLSkiq1KdI=92S~nEu_PG}j8_zf;d#ck z^+mr5uSH>!b=@=|Z1?OwQq?!wk)Bm|4p?7F61qr!&~4;z(y3H95ymU4p8sR5EYb@P z_x5SHU0g3$2y^39^{OEBEP62bUA)3*`mQVQ4=);YAV$xrWtX!k`)dQtdIu7odmA$g z7Q-k6VS@S$K!WJ&c(%gy#ajB)QD6Rb7C%H|tYI9M=%Td$bznTxT0zye&S_j4K_mn8up3U_;jPi7vt#YDq4!PEyh!_TnnfbqWfwqQADr08n?{^vzU<$~bb!n%hDwJjy z*O70t$9_i`l+J}hUDGO)8Hn&;>eZGzzR?Q{w^y&t8JSj`cL!IwKdos?XG!Ptrnpbr zD2}4#9+iAe8{SB3*+K$}5PZM}i+8HHYzfN*8kUmb5I8m>FIF>xx zb;KxKjb?~X;j_~Uw&0(xIc*;?r!aa^E(9mrn4@|l@<>6>Yx3ZnXu;6gQ*3U??Y(%p z+urMZ4ZD-GB-`G_#2&WM;}RBkSg$gX498MpmZoToenEO z6!mf(o}wX#LZ5_0@Y{oHIu5^24b7Jk@&~#Ir^+*+Ka=JTu z2ZbMq-&b7#cKMeJZ7StX-O>pnxgfH8U?T;5)^p*a$_mpFx=-IT%8z(ddfM8KHf(lV zJY{g1{^?Rik6l_=?sqPmZBJAFmGiJOd6vFJ0|(f;8UFB)xl*F>Lzsp1rP{Quo5#Kp z?0qYaTH2F&N*)0V0Sk1F>(L6KU*hVrGp~AZ76*kEsvs2trPyIgr{diCn^$rnA8=D#Y| zV;=K$3~8#PZt607-De{*R+E8F)Lf6Gu@dn-@m9#zJ@`c&v5Bq%O=OEn?Zb(D(t z2?lR-T=r4`z5@T|yu3J_8-8Eduj<=6BW_LMBDk6AGSN1x58C52iQ_=sGh`62N5F1@ zMX&YPgsd`Moh1tU4_zk)xY{ueY#yi)(}9%!WnEf@ZgIXF97g_0J%6}V?LmX=WY~v| z2+3W2+Z39R__ytgo3TBc>R0K}BB01$oIO6`Wy~j8@$qALG*289|FXubTlWdp z{Yq{8{HOcjW<^UjfgCpdFn(n8pA|_EkBwbdm-WWE+?^#Yk z*176jBu&VPQ8`uRRCy1UdgY$b!}-ri{hdp9ak?kmA-N5E$@-|qEhm%_7CSGu1R!Pn zH+LLZOA!6kF7i!Kr{xwFLs4>)$%j6OXU zx?L@3jCAuv@YHvQ%QecNorvq{^yO-@7Jjp@u3d%pOh22^KWWj?gUr3%z2dh-$l>AetT%9 zpo?*9DsfRkIbPUMQ}fx*LJ08yx7yhYX@mW_Fa2r#sAx;2Tl9ZrMd34fWoH=a_|+qQ z-$^S_HrC)csae4VD7Z><5 z3d|YfVxow|pvTQ2w$g1y7?m6DmZ+SYdDIc8G zn7ThF+590YL|r#y%#E>fGmfj0XwHDXW?QtMm=>{wAC6ZcF0l!U|0xv+D0B0%HzneW}YM#m&m(ee#yWJJRms-aiNnqo<+kduzS|>P9xi`_QXrm+!^5EVDE+ z;1<6Bzz|E@q&SsyB|61u#F#+%H2}lvwC%)Zo+nlOJX|15FK)(1tTY04dzxVJMjE+P z&GVddSkS)S-u>tcu4n=Y2+u_8=-jlpB77(2$BYei!@cGKk780>=~9&N6+?Z|2jIl# zQ%d@=O)(Lhb=BJK=`wixD;qOch8PbhHm3)k#TaI{@+zzF)?$I9s)0-20w78X{UPwR zR2f`fC`Os~?hGRvLvi>TT0 zuW-Zd>Uxyt+&x(chW^z<^gf32N&N*Vv&wIMPQXHs*4XX==Z2z-;+K zJ_o!1gqQW87v+Ktc~>irH^;&;jq?yGGt}buuwpH$Bs($|+{aeAe|{1%=7e=7gu(S^ zxKXoa&Q)`C^(~~P04>jH%wJhLw<9&#nHeVFF0>@A>yVv+d3`90Q}c z)e*0>F>~tRIXz|hdzS{V1>f)$2 zj9Olb24VICBdPWCB2NMEgjEAmaOO$h3s5A%*l=$3Sk7rR7@|Gl~I_fo(oB~&oMGqW%!n8XK z!#?Ty+3PPM5tBftNNj8r-e^|VL$?TVAt`MIZ#wx?;j`efD2X56eauFqXS}84{*v;x zHR6V5LWdP)-fQ0UZg%gTzOAp_Il5huG}f{w`CzM)@}~=W2@efN>{#!|&krJ8jIcyL zJ;l%G6jZz^vcufTyE4uotSxr6>r)soz^3gP!*M3tnk>}&P*!cN;ZTWV>Ygz&`Rq_G z!oQma7xS#Tsg^+L)iny%e~O{9amjygAOqebHhclgyr$zA<+J}JGJU(FhF$1s-pcTe z<@&WO)7ziR{ehH$U)O{@Bd9@2dXeqe&~{<1B_t~o{;U$S>6|MK(NBzSxQ%TGz~b2O z`E>lj#9ihGp*$?pKc|Q0>}3$^oagrxEHnn9XMBX_1Ppt#hnyaPfBQ5VVE5=L)}P^T zejT{`VM*BVTb;W|VRYAihBSDGLo2aHKz>y+U4Ru%)gy?OhGh3{CO-<0*ARAXw- zPb^>)t(zIg90j>*dy?`Ed~(S?^;%mlRWA{fg~%)l?0ge8gQ43XfPekUvMw>qJi`&h z(QKV%^THw@(HB?pSB?0Wo`!yE*pn){xPOE8%VF!XOWnyd9KP}~h5Y}1V4|^~ zcSZFqy=arj>zLWP9@AsmEtbL(e(CUyUscNeqn+LFi>v_g-mHiU0VBVQL;Tg9l@*JmnE*}64; zKf7c$kmXV{Q<2QUA$=;@50CB-EPE9xWLSF!dUU?m!TPA>Ll511$|4}ob!}~U2^2Gg zuq%eb>^61?Y_%W%uxD!-6%rfwqbYy1$>-_XsJ>Cmo)Zzy?sbXxGu9)B&e6L9DUAjE z^V=ksEkYfM?$+O(u29G9>7oxJ9-I7u-rhw_IsB)BpHD%9$F9R4s=4-ybBhOrTwbm( zjR{Xv_i+&LW0$&-`q{+{?xJ(ZP=dfe%;v7nSZ$%p5(!TCDZSM5DdCwdEPRiSq1a~R z)+gk-O?y7-)!fS_+taDT<58D)BbFnFhu-cMZ3VujG*S-7;{&oamoI)sf&D}_mOrBS z*UKM}>D-_wWkfY5%~Fohbnow&xK7fS9+0ySg@qZ@&fkmqZq;};!Cq=`7&YB`qiGyGC*BTR#|t$7H!62 z6wM>-ylxygRVz8GB5Qy1vu@Dr^(|ARG+o7us!qxE@`i)9ZQC=^Qu(o~HG+{oj$0SC z%{OdKxMQ(AGgyK_Vy$)ZJ8n1^bUwJ!5sOzy!b9WPO>a!5yUjER#aZaZv->ODrd!tD z{~~r|u%+&LNg{-%1y5(X{%c)?L~c!C-KIWnlbL>k>Lwu7wy0U~D~^8qn7mVqB;>3! z>l;KeBUv|dYHJp2`#H^kv(U5bAP|THs#s9!XM;PS_)VA?i^gEJxPJqik}w?l|CKhJ zchJ7U9i&vGm*sc7?%|zMgEIo?WpA%iAzHFE;Pr6yHT@1n;GH^lvr9{iF3h8l6(ub+ zc51_P{%QN=xPFuxZog4ihbqnh>C3ziZ^eDq zL|M|uC>{9L5hn318MK$x9-Iqc)8CRoV7*RTM0SI==-dl;n>otg$IsZHu@dKCboF1j zPs8wVg3_Y6_io7}p#OhJaKpeRkcTGd|J`}V#&G7GfK#YWH|1OCE}r3*Pi;yXBP^`p zhzp_(^9&(F!4Ky0CAu?3;3AV#m5CNRAY2{9^v6h(hq-U)4W=XOh3n5ob8)etkiN`I zBp&2{gZ*X$1gBVP_nkQl`nzS=PAobTrox${{~?CLYr?GNz7KsDlwYUSUiX72!_J(G zAK6uRm&&(6nBs)bVoFqk(Bs+B>_cIhOsH`aZ7)ck#eX-b0Js^rfWnE7? z^SpM8^c|%&JP;NQ;<;r(PYs}&9NOFL+JYoCM39xJ+_a-}@NusHW`Ns976hb8>>|Ag zOpv5#>W$@)-w!-d3_9qU*0YVRj6eHvDC~{166(;)!H!$*Tm3Ek zKmE@zy9OR0e{O(BUs-X2{?u9VAYCvtY9xI#Dd-d$$X(% z2%A31pBlXFnggE?no}Yz?$AnB+14gohYj`L$@GJ8O*VSJ&^`phdX8Av`@EXBM=TAz zBXCI8jt&;!i=t3KjKWE)tGy<8!DXWme4nxx@a@FN=pY%g1YSPYJ+$+$hD0|KM<~&Z z+>3CV2_+*D#~3J8_Ac*x?$!=xs(4O6L3^cqZyBzvTa{nma92{FyxvT|06fsoC7Aqv zgPKk~d3z1d3T#q?t1*R))gr@i+!a(xHL^l^9C-5dD~qw;upF?IcMz=y+L1H4cZJdJ zH(UTP{JBf_kPQ#x(JH&*QNxl<-?2v{gU)a&VT<=R&)2uO<;W}B|Y#b+v{BMhRtOn?Q#2irSW3DCEAD2@40Sd0HvwHpVtN7xIkMyi!)3Lr> zUJme_hu(b=8)~PVNqIifdC-8o229WIsbm8x)}P(vpOUq-Y_vAiU+frvpbp?#k2%ep zGtl)>Bp<%?)Y4qRrOM8jmX;PC+4PI2TNucra%6jSd10c3eGH2#kjaf9@gdtd`Shj+ zVo?+=XI#k=Mo#}j!#O9JD{gtTAQxY?vfFhBJ*JR-C&Zi%H0;Tg#qM!!hq1pAD6?#H zEsf09@Cnuk7FTxNkRcgD3Rz>aZk|SyaYr)0%mr@!`k16B^=)nYbZgW*;@)h02JQh6 zb4sfw90(Qy5^5e3sut9;x7Nku0C+4b-B0;n)~MosCeupY|AZ1&N4Ve>mw778z9f`{ zvvogfsM>S9ubNC}z(sVDZy78eI1oQC-(>km3z`b6g%%4I=`XqV*P)qq#+yM=##Zmq zH3pZ!sAh?-KjeW>sD+JFUVUM>Dde5>96jO*Jgp4#6u9Fn0I zk>%GX*_@<8>df_bfuE*7f(6$8w*S)E0ibR1Z_B+szorVaTt&hx0@h36VzSw=jKdQA zcIz4N+3jj4-}1;~Cj`A;{J>l(_qAD35Zmvdr#d!~BLbbVfS>;2@%f*)W3KfgB zYIX&>3YpJR3Zq(g!WsI5_1_VSaQAf_g2_pw8s()|j}6b@@g9cB?DI;o_<<7I%kOxa zR$%OKQnnsBU%@XfI0`wZ4sYl3b(l5>`pU274s^GxKY~6t2tf!sLd)rlGOi55u_uHs zZ$1#p+E2nR4+Xa2#7B1!-4mtZ0!i~qM{4O0k^+%xMJD~(M3dfrPtn5%2iyAc1nL4? zoZUcm#uFLLX9^yjAE zhrA8~FOs&F(WnD49r5y23p&Dt;td5Mf0vt{6}pLt_0Z~H!FdxctR#=CXBsHxhMaZ- z@{Jh->cX$S83SQ(-Po6;I&4XeMfJzEPuzuTz%eIuVTZt_^)p^N(TkjvQw5xBetwDi zTw8&5vj)0>Dlkd0h_?g8_WI$XP#7`woI46C?R7U)C))|TRwu2W5g#xsju6^vMV%!q zVg_;ks%d2dWc}U@PHrH{GEKl%@|0&jWRdAfzl%73N72tyz}fo?xr$l}^#*@w zaE9+rclY7lD{He;*m|t?Bh4(lW0LFTiToA2-N&R@k=EGm zH_ZF@@tPXmpN^uE=HXiB^ED}AYt6rR$bM$!j(a3B4XVG&r>`%9)^}wK53VT2$`Ntd zxg1c5VQbh5bmKorN)c5vqA!?udj@Y>j=JDI5Z%6rH^AU~%Z2}P;R3WAaAgLq5mfsD zTYXsiPNis+cs;h77ETkiX)Qf?)Q+M607N&?Ev!`fiT(#<*hsISk1oCRG=j~a)RYOI z%J0c2FQ{vv$tdm{)`!FUZ>`c*yYK@!V-}VM@9MnoQk$-3h*M#sT>z)g5YXFx>`>G2 zke0X;iQy1^rGn;Yb})f8d#?jz*&!jQ>NTEH|8^=u$}(lvti^ zn(~{6qF{(f(uple^5DaKrw6w|brCzzSagCBDskBH+%~rzkj{QlzF>Q~aS^`tnofi@ zihPX~aip$~EZ5JXrQka;H*ZK1kN?u#_bHGe3>YLaIR(bMcQz{@066PULr2#*+lZ%| zQCr0CRK@{8hFJd2*HsODq;72K|K25lxPjHBjWc7*_4W1NJ3H5#dJ3^=fgZ&{%7cBe zq|w?ZtKN^lN{YC&5$HIwWJm%LUAF#l{AfN}`nN)w5zaTOQd>j@oPd%owD-kx=tDU% z?JKePWCl2eO{yk##LJu;Mo4ob?Gkgyb*R0~L&-y23FC%P4Kmo%1D`kM>hMI&C}eD? z?cifh!q4ipQ@>y7GFzYNU8tLh@p)iR2fWBpVy_nlx7ESAIFj_=1T7vXTbycbB8|e= zgb=_G4!x&2ndLlE4~Xc+4>t#{msN0vC=jUt)R2k5`+0FRh2^>@1|wFx-oa0}nZ-Gf z*+f(Kc3`KZSRz`^b`+cAU=9wYZ)y}=pL~#Xv^Xwv$!omq@BPUd&r4nHp*B2LTF0%EOycmP2joye6oHqsKV%G^=#-i8oieMi z)EASlKK8eD^(6t7w2!w%LiBu3GMeXond^>3SAl_GzSHr9x_$s&$$dvtykq$ZnGTCp{VkD$S*j-!68E>headBC z`tqeAyH==7s~b@sFFhT(uVU?zKd^VzjUeH!%wst>j8UOw;E_FR$AjQPvo$f{Rpp>e zYF7!*gpb4^#tm%*AZb!C{}vT5ms|5t;?=uBXQ9G7N8U3Q)i8F@A{Dp9#cMuCNoW*L z?(ide&KK=!?*kSwM$ngx!Z7dUYIHK+OF)X}f)h5_;j$J?IB|1RT0;WmHrG2Z^m2d! zPl4d69OEqqaPXt>(^8x`*)dvAUT8B7^9dPlU@-%X`-JxPRdPV15-w=Le7I4G3s;sD z_e=qc@yt1KkGuT6V_Ytq93yGOiAwEFX*pG#YNB<- z$yN8wb=7(ZxQtUdca}LM+^b??3h1Q$PU#YLOb_aQP!GN(Uwao=$S1dvPBK6hsMjyh zd3fjg%%~v{wOQtC_Vx0JH>$IwHN4qG-DJXlzPUbPQuI58FV^SWu8V~A%QwmFJalJ3 zp*gpNdu7>}HETKo!*@)Opxj6J^Pl~~8t}ttWj%8z_%{4xhYYx$LSXYhOaVyLg$2DA zQQwr0vd??LoGTLUa@-~i@wo{z*86E}sF2AaXe(29d^>ktzO8}775$S9 zwjQVZO+m0AZ1J>c-s)TABP0rQY>1JrQNay6L?t?bx|cUAT8PmEOhQTafp#xJS)?-1 zo4Cezs$JRtO9sp9<_>7sXEyWw+7Xd`zBkh=0Vcz^zhfFMT^mneSId)JkD>WFf{JZ} zIr7G>n7!rmsw}UTiG!Yn^3QnNY5U6gy6I|MVlI1h*q(+L^#0>OS-zt*H`Wz9;NO?o z^$$TUE$|E4Ov^V67o<1Z{|GKiamfKmz9RV8ljhf*ayiOq74WVyacNXGf!50sWGi#; zQhd5K<|oxFopZ^r?z7;tFcuJN?s!M}=8>lU(>kCW7(3juNhpgxCohpTpU}EFFkp~W zR{lwvv(=;#0t&HR6}0V%RMb4sf4JN^Qnwws-n_@%%3#f+=L>((|J@m0Pn1uhi#e(> zgQeg|=)LuY$xe!}6kiV~Ylnd%?mZ#r;wF=JMjVzmmKz}HhM#FMc6osl(6$Tto5|&b zT>N`hoi%OeQu}b;Fpk)L1RCRUid6we_Q;P{wDWmkYD{l|rhO@_%4~)=(E`s*mp